Add OptionalPath extractor (#1889)

Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
This commit is contained in:
Jonas Platte 2023-04-09 14:23:13 +02:00 committed by GitHub
parent 946d8c3253
commit 43b2d52403
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 124 additions and 13 deletions

View file

@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning].
# Unreleased
- None.
- **added:** Add `OptionalPath` extractor ([#1889])
[#1889]: https://github.com/tokio-rs/axum/pull/1889
# 0.7.2 (22. March, 2023)

View file

@ -19,11 +19,10 @@ cookie = ["dep:cookie"]
cookie-private = ["cookie", "cookie?/private"]
cookie-signed = ["cookie", "cookie?/signed"]
cookie-key-expansion = ["cookie", "cookie?/key-expansion"]
erased-json = ["dep:serde_json", "dep:serde"]
form = ["dep:serde", "dep:serde_html_form"]
erased-json = ["dep:serde_json"]
form = ["dep:serde_html_form"]
json-lines = [
"dep:serde_json",
"dep:serde",
"dep:tokio-util",
"dep:tokio-stream",
"tokio-util?/io",
@ -31,8 +30,8 @@ json-lines = [
]
multipart = ["dep:multer"]
protobuf = ["dep:prost"]
query = ["dep:serde", "dep:serde_html_form"]
typed-routing = ["dep:axum-macros", "dep:serde", "dep:percent-encoding", "dep:serde_html_form", "dep:form_urlencoded"]
query = ["dep:serde_html_form"]
typed-routing = ["dep:axum-macros", "dep:percent-encoding", "dep:serde_html_form", "dep:form_urlencoded"]
[dependencies]
axum = { path = "../axum", version = "0.6.9", default-features = false }
@ -42,6 +41,7 @@ http = "0.2"
http-body = "0.4.4"
mime = "0.3"
pin-project-lite = "0.2"
serde = "1.0"
tokio = "1.19"
tower = { version = "0.4", default_features = false, features = ["util"] }
tower-http = { version = "0.4", features = ["map-response-body"] }
@ -55,7 +55,6 @@ form_urlencoded = { version = "1.1.0", optional = true }
multer = { version = "2.0.0", optional = true }
percent-encoding = { version = "2.1", optional = true }
prost = { version = "0.11", optional = true }
serde = { version = "1.0", optional = true }
serde_html_form = { version = "0.2.0", optional = true }
serde_json = { version = "1.0.71", optional = true }
tokio-stream = { version = "0.1.9", optional = true }
@ -63,6 +62,7 @@ tokio-util = { version = "0.7", optional = true }
[dev-dependencies]
axum = { path = "../axum", version = "0.6.0", features = ["headers"] }
axum-macros = { path = "../axum-macros", version = "0.3.7", features = ["__private"] }
http-body = "0.4.4"
hyper = "0.14"
reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "multipart"] }

View file

@ -1,6 +1,8 @@
//! Additional extractors.
mod cached;
mod optional_path;
mod with_rejection;
#[cfg(feature = "form")]
mod form;
@ -14,9 +16,7 @@ mod query;
#[cfg(feature = "multipart")]
pub mod multipart;
mod with_rejection;
pub use self::cached::Cached;
pub use self::{cached::Cached, optional_path::OptionalPath, with_rejection::WithRejection};
#[cfg(feature = "cookie")]
pub use self::cookie::CookieJar;
@ -39,5 +39,3 @@ pub use self::multipart::Multipart;
#[cfg(feature = "json-lines")]
#[doc(no_inline)]
pub use crate::json_lines::JsonLines;
pub use self::with_rejection::WithRejection;

View file

@ -0,0 +1,102 @@
use axum::{
async_trait,
extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts, Path},
RequestPartsExt,
};
use serde::de::DeserializeOwned;
/// Extractor that extracts path arguments the same way as [`Path`], except if there aren't any.
///
/// This extractor can be used in place of `Path` when you have two routes that you want to handle
/// in mostly the same way, where one has a path parameter and the other one doesn't.
///
/// # Example
///
/// ```
/// use std::num::NonZeroU32;
/// use axum::{
/// response::IntoResponse,
/// routing::get,
/// Router,
/// };
/// use axum_extra::extract::OptionalPath;
///
/// async fn render_blog(OptionalPath(page): OptionalPath<NonZeroU32>) -> impl IntoResponse {
/// // Convert to u32, default to page 1 if not specified
/// let page = page.map_or(1, |param| param.get());
/// // ...
/// }
///
/// let app = Router::new()
/// .route("/blog", get(render_blog))
/// .route("/blog/:page", get(render_blog));
/// # let app: Router = app;
/// ```
#[derive(Debug)]
pub struct OptionalPath<T>(pub Option<T>);
#[async_trait]
impl<T, S> FromRequestParts<S> for OptionalPath<T>
where
T: DeserializeOwned + Send + 'static,
S: Send + Sync,
{
type Rejection = PathRejection;
async fn from_request_parts(
parts: &mut http::request::Parts,
_: &S,
) -> Result<Self, Self::Rejection> {
match parts.extract::<Path<T>>().await {
Ok(Path(params)) => Ok(Self(Some(params))),
Err(PathRejection::FailedToDeserializePathParams(e))
if matches!(e.kind(), ErrorKind::WrongNumberOfParameters { got: 0, .. }) =>
{
Ok(Self(None))
}
Err(e) => Err(e),
}
}
}
#[cfg(test)]
mod tests {
use std::num::NonZeroU32;
use axum::{routing::get, Router};
use super::OptionalPath;
use crate::test_helpers::TestClient;
#[crate::test]
async fn supports_128_bit_numbers() {
async fn handle(OptionalPath(param): OptionalPath<NonZeroU32>) -> String {
let num = param.map_or(0, |p| p.get());
format!("Success: {num}")
}
let app = Router::new()
.route("/", get(handle))
.route("/:num", get(handle));
let client = TestClient::new(app);
let res = client.get("/").send().await;
assert_eq!(res.text().await, "Success: 0");
let res = client.get("/1").send().await;
assert_eq!(res.text().await, "Success: 1");
let res = client.get("/0").send().await;
assert_eq!(
res.text().await,
"Invalid URL: invalid value: integer `0`, expected a nonzero u32"
);
let res = client.get("/NaN").send().await;
assert_eq!(
res.text().await,
"Invalid URL: Cannot parse `\"NaN\"` to a `u32`"
);
}
}

View file

@ -97,6 +97,9 @@ pub mod __private {
pub const PATH_SEGMENT: &AsciiSet = &PATH.add(b'/').add(b'%');
}
#[cfg(test)]
use axum_macros::__private_axum_test as test;
#[cfg(test)]
pub(crate) mod test_helpers {
#![allow(unused_imports)]

View file

@ -266,7 +266,8 @@ impl std::error::Error for PathDeserializationError {}
/// The kinds of errors that can happen we deserializing into a [`Path`].
///
/// This type is obtained through [`FailedToDeserializePathParams::into_kind`] and is useful for building
/// This type is obtained through [`FailedToDeserializePathParams::kind`] or
/// [`FailedToDeserializePathParams::into_kind`] and is useful for building
/// more precise error messages.
#[derive(Debug, PartialEq, Eq)]
#[non_exhaustive]
@ -380,6 +381,11 @@ impl fmt::Display for ErrorKind {
pub struct FailedToDeserializePathParams(PathDeserializationError);
impl FailedToDeserializePathParams {
/// Get a reference to the underlying error kind.
pub fn kind(&self) -> &ErrorKind {
&self.0.kind
}
/// Convert this error into the underlying error kind.
pub fn into_kind(self) -> ErrorKind {
self.0.kind