mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-11 12:31:25 +01:00
Percent decode automatically in extract::Path
(#272)
* Percent decode automatically in `extract::Path` Fixes https://github.com/tokio-rs/axum/issues/261 * return an error if path param contains invalid utf-8 * Mention automatic decoding in the docs * Update changelog: This is a breaking change * cleanup * fix tests
This commit is contained in:
parent
2c2bcd7754
commit
afabded385
7 changed files with 156 additions and 58 deletions
|
@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
# Unreleased
|
||||
|
||||
- Improve performance of `BoxRoute` ([#339])
|
||||
- **breaking:** Automatically do percent decoding in `extract::Path`
|
||||
([#272])
|
||||
- **breaking:** `Router::boxed` now the inner service to implement `Clone` and
|
||||
`Sync` in addition to the previous trait bounds ([#339])
|
||||
- **breaking:** Added feature flags for HTTP1 and JSON. This enables removing a
|
||||
|
@ -16,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
[#339]: https://github.com/tokio-rs/axum/pull/339
|
||||
[#286]: https://github.com/tokio-rs/axum/pull/286
|
||||
[#272]: https://github.com/tokio-rs/axum/pull/272
|
||||
|
||||
# 0.2.6 (02. October, 2021)
|
||||
|
||||
|
|
25
Cargo.toml
25
Cargo.toml
|
@ -27,30 +27,32 @@ ws = ["tokio-tungstenite", "sha-1", "base64"]
|
|||
async-trait = "0.1"
|
||||
bitflags = "1.0"
|
||||
bytes = "1.0"
|
||||
dyn-clone = "1.0"
|
||||
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
|
||||
http = "0.2"
|
||||
http-body = "0.4.3"
|
||||
hyper = { version = "0.14", features = ["server", "tcp", "stream"] }
|
||||
percent-encoding = "2.1"
|
||||
pin-project-lite = "0.2.7"
|
||||
regex = "1.5"
|
||||
serde = "1.0"
|
||||
serde_json = { version = "1.0", optional = true }
|
||||
serde_urlencoded = "0.7"
|
||||
sync_wrapper = "0.1.1"
|
||||
tokio = { version = "1", features = ["time"] }
|
||||
tokio-util = "0.6"
|
||||
tower = { version = "0.4", default-features = false, features = ["util", "buffer", "make"] }
|
||||
tower-service = "0.3"
|
||||
tower-layer = "0.3"
|
||||
tower-http = { version = "0.1", features = ["add-extension", "map-response-body"] }
|
||||
sync_wrapper = "0.1.1"
|
||||
tower-layer = "0.3"
|
||||
tower-service = "0.3"
|
||||
|
||||
# optional dependencies
|
||||
tokio-tungstenite = { optional = true, version = "0.15" }
|
||||
sha-1 = { optional = true, version = "0.9.6" }
|
||||
base64 = { optional = true, version = "0.13" }
|
||||
headers = { optional = true, version = "0.3" }
|
||||
multer = { optional = true, version = "2.0.0" }
|
||||
mime = { optional = true, version = "0.3" }
|
||||
multer = { optional = true, version = "2.0.0" }
|
||||
serde_json = { version = "1.0", optional = true }
|
||||
sha-1 = { optional = true, version = "0.9.6" }
|
||||
tokio-tungstenite = { optional = true, version = "0.15" }
|
||||
|
||||
[dev-dependencies]
|
||||
futures = "0.3"
|
||||
|
@ -82,4 +84,11 @@ all-features = true
|
|||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
|
||||
[package.metadata.playground]
|
||||
features = ["ws", "multipart", "headers"]
|
||||
features = [
|
||||
"http1",
|
||||
"http2",
|
||||
"json",
|
||||
"multipart",
|
||||
"tower",
|
||||
"ws",
|
||||
]
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use crate::routing::UrlParams;
|
||||
use crate::util::ByteStr;
|
||||
use crate::util::{ByteStr, PercentDecodedByteStr};
|
||||
use serde::{
|
||||
de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor},
|
||||
forward_to_deserialize_any, Deserializer,
|
||||
|
@ -53,20 +52,20 @@ macro_rules! parse_single_value {
|
|||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
if self.url_params.0.len() != 1 {
|
||||
if self.url_params.len() != 1 {
|
||||
return Err(PathDeserializerError::custom(
|
||||
format!(
|
||||
"wrong number of parameters: {} expected 1",
|
||||
self.url_params.0.len()
|
||||
self.url_params.len()
|
||||
)
|
||||
.as_str(),
|
||||
));
|
||||
}
|
||||
|
||||
let value = self.url_params.0[0].1.parse().map_err(|_| {
|
||||
let value = self.url_params[0].1.parse().map_err(|_| {
|
||||
PathDeserializerError::custom(format!(
|
||||
"can not parse `{:?}` to a `{}`",
|
||||
self.url_params.0[0].1.as_str(),
|
||||
self.url_params[0].1.as_str(),
|
||||
$tp
|
||||
))
|
||||
})?;
|
||||
|
@ -76,12 +75,12 @@ macro_rules! parse_single_value {
|
|||
}
|
||||
|
||||
pub(crate) struct PathDeserializer<'de> {
|
||||
url_params: &'de UrlParams,
|
||||
url_params: &'de [(ByteStr, PercentDecodedByteStr)],
|
||||
}
|
||||
|
||||
impl<'de> PathDeserializer<'de> {
|
||||
#[inline]
|
||||
pub(crate) fn new(url_params: &'de UrlParams) -> Self {
|
||||
pub(crate) fn new(url_params: &'de [(ByteStr, PercentDecodedByteStr)]) -> Self {
|
||||
PathDeserializer { url_params }
|
||||
}
|
||||
}
|
||||
|
@ -114,13 +113,13 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
|
|||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
if self.url_params.0.len() != 1 {
|
||||
if self.url_params.len() != 1 {
|
||||
return Err(PathDeserializerError::custom(format!(
|
||||
"wrong number of parameters: {} expected 1",
|
||||
self.url_params.0.len()
|
||||
self.url_params.len()
|
||||
)));
|
||||
}
|
||||
visitor.visit_str(&self.url_params.0[0].1)
|
||||
visitor.visit_str(&self.url_params[0].1)
|
||||
}
|
||||
|
||||
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
|
@ -157,7 +156,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
|
|||
V: Visitor<'de>,
|
||||
{
|
||||
visitor.visit_seq(SeqDeserializer {
|
||||
params: &self.url_params.0,
|
||||
params: self.url_params,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -165,18 +164,18 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
|
|||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
if self.url_params.0.len() < len {
|
||||
if self.url_params.len() < len {
|
||||
return Err(PathDeserializerError::custom(
|
||||
format!(
|
||||
"wrong number of parameters: {} expected {}",
|
||||
self.url_params.0.len(),
|
||||
self.url_params.len(),
|
||||
len
|
||||
)
|
||||
.as_str(),
|
||||
));
|
||||
}
|
||||
visitor.visit_seq(SeqDeserializer {
|
||||
params: &self.url_params.0,
|
||||
params: self.url_params,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -189,18 +188,18 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
|
|||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
if self.url_params.0.len() < len {
|
||||
if self.url_params.len() < len {
|
||||
return Err(PathDeserializerError::custom(
|
||||
format!(
|
||||
"wrong number of parameters: {} expected {}",
|
||||
self.url_params.0.len(),
|
||||
self.url_params.len(),
|
||||
len
|
||||
)
|
||||
.as_str(),
|
||||
));
|
||||
}
|
||||
visitor.visit_seq(SeqDeserializer {
|
||||
params: &self.url_params.0,
|
||||
params: self.url_params,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -209,7 +208,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
|
|||
V: Visitor<'de>,
|
||||
{
|
||||
visitor.visit_map(MapDeserializer {
|
||||
params: &self.url_params.0,
|
||||
params: self.url_params,
|
||||
value: None,
|
||||
})
|
||||
}
|
||||
|
@ -235,21 +234,21 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
|
|||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
if self.url_params.0.len() != 1 {
|
||||
if self.url_params.len() != 1 {
|
||||
return Err(PathDeserializerError::custom(format!(
|
||||
"wrong number of parameters: {} expected 1",
|
||||
self.url_params.0.len()
|
||||
self.url_params.len()
|
||||
)));
|
||||
}
|
||||
|
||||
visitor.visit_enum(EnumDeserializer {
|
||||
value: &self.url_params.0[0].1,
|
||||
value: &self.url_params[0].1,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct MapDeserializer<'de> {
|
||||
params: &'de [(ByteStr, ByteStr)],
|
||||
params: &'de [(ByteStr, PercentDecodedByteStr)],
|
||||
value: Option<&'de str>,
|
||||
}
|
||||
|
||||
|
@ -519,7 +518,7 @@ impl<'de> VariantAccess<'de> for UnitVariant {
|
|||
}
|
||||
|
||||
struct SeqDeserializer<'de> {
|
||||
params: &'de [(ByteStr, ByteStr)],
|
||||
params: &'de [(ByteStr, PercentDecodedByteStr)],
|
||||
}
|
||||
|
||||
impl<'de> SeqAccess<'de> for SeqDeserializer<'de> {
|
||||
|
@ -561,18 +560,16 @@ mod tests {
|
|||
a: i32,
|
||||
}
|
||||
|
||||
fn create_url_params<I, K, V>(values: I) -> UrlParams
|
||||
fn create_url_params<I, K, V>(values: I) -> Vec<(ByteStr, PercentDecodedByteStr)>
|
||||
where
|
||||
I: IntoIterator<Item = (K, V)>,
|
||||
K: AsRef<str>,
|
||||
V: AsRef<str>,
|
||||
{
|
||||
UrlParams(
|
||||
values
|
||||
.into_iter()
|
||||
.map(|(k, v)| (ByteStr::new(k), ByteStr::new(v)))
|
||||
.collect(),
|
||||
)
|
||||
values
|
||||
.into_iter()
|
||||
.map(|(k, v)| (ByteStr::new(k), PercentDecodedByteStr::new(v).unwrap()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
macro_rules! check_single_value {
|
||||
|
@ -601,6 +598,7 @@ mod tests {
|
|||
check_single_value!(f32, "123", 123.0);
|
||||
check_single_value!(f64, "123", 123.0);
|
||||
check_single_value!(String, "abc", "abc");
|
||||
check_single_value!(String, "one%20two", "one two");
|
||||
check_single_value!(char, "a", 'a');
|
||||
|
||||
let url_params = create_url_params(vec![("a", "B")]);
|
||||
|
|
|
@ -1,14 +1,24 @@
|
|||
mod de;
|
||||
|
||||
use super::{rejection::*, FromRequest};
|
||||
use crate::{extract::RequestParts, routing::UrlParams};
|
||||
use crate::{
|
||||
extract::RequestParts,
|
||||
routing::{InvalidUtf8InPathParam, UrlParams},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
ops::{Deref, DerefMut},
|
||||
};
|
||||
|
||||
/// Extractor that will get captures from the URL and parse them using
|
||||
/// [`serde`].
|
||||
///
|
||||
/// Any percent encoded parameters will be automatically decoded. The decoded
|
||||
/// parameters must be valid UTF-8, otherwise `Path` will fail and return a `400
|
||||
/// Bad Request` response.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
|
@ -140,20 +150,45 @@ where
|
|||
{
|
||||
type Rejection = PathParamsRejection;
|
||||
|
||||
#[allow(warnings)]
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
const EMPTY_URL_PARAMS: &UrlParams = &UrlParams(Vec::new());
|
||||
|
||||
let url_params = if let Some(params) = req
|
||||
let params = match req
|
||||
.extensions_mut()
|
||||
.and_then(|ext| ext.get::<Option<UrlParams>>())
|
||||
{
|
||||
params.as_ref().unwrap_or(EMPTY_URL_PARAMS)
|
||||
} else {
|
||||
return Err(MissingRouteParams.into());
|
||||
Some(Some(UrlParams(Ok(params)))) => Cow::Borrowed(params),
|
||||
Some(Some(UrlParams(Err(InvalidUtf8InPathParam { key })))) => {
|
||||
return Err(InvalidPathParam::new(key.as_str()).into())
|
||||
}
|
||||
Some(None) => Cow::Owned(Vec::new()),
|
||||
None => {
|
||||
return Err(MissingRouteParams.into());
|
||||
}
|
||||
};
|
||||
|
||||
T::deserialize(de::PathDeserializer::new(url_params))
|
||||
T::deserialize(de::PathDeserializer::new(&*params))
|
||||
.map_err(|err| PathParamsRejection::InvalidPathParam(InvalidPathParam::new(err.0)))
|
||||
.map(Path)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tests::*;
|
||||
use crate::{handler::get, Router};
|
||||
|
||||
#[tokio::test]
|
||||
async fn percent_decoding() {
|
||||
let app = Router::new().route(
|
||||
"/:key",
|
||||
get(|Path(param): Path<String>| async move { param }),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/one%20two").send().await;
|
||||
|
||||
assert_eq!(res.text().await, "one two");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -107,7 +107,7 @@ define_rejection! {
|
|||
/// Rejection type for [`Path`](super::Path) if the capture route
|
||||
/// param didn't have the expected type.
|
||||
#[derive(Debug)]
|
||||
pub struct InvalidPathParam(String);
|
||||
pub struct InvalidPathParam(pub(crate) String);
|
||||
|
||||
impl InvalidPathParam {
|
||||
pub(super) fn new(err: impl Into<String>) -> Self {
|
||||
|
|
|
@ -9,7 +9,7 @@ use crate::{
|
|||
OriginalUri,
|
||||
},
|
||||
service::HandleError,
|
||||
util::ByteStr,
|
||||
util::{ByteStr, PercentDecodedByteStr},
|
||||
BoxError,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
|
@ -627,24 +627,49 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct UrlParams(pub(crate) Vec<(ByteStr, ByteStr)>);
|
||||
// we store the potential error here such that users can handle invalid path
|
||||
// params using `Result<Path<T>, _>`. That wouldn't be possible if we
|
||||
// returned an error immediately when decoding the param
|
||||
pub(crate) struct UrlParams(
|
||||
pub(crate) Result<Vec<(ByteStr, PercentDecodedByteStr)>, InvalidUtf8InPathParam>,
|
||||
);
|
||||
|
||||
fn insert_url_params<B>(req: &mut Request<B>, params: Vec<(String, String)>) {
|
||||
let params = params
|
||||
.into_iter()
|
||||
.map(|(k, v)| (ByteStr::new(k), ByteStr::new(v)));
|
||||
.map(|(k, v)| {
|
||||
if let Some(decoded) = PercentDecodedByteStr::new(v) {
|
||||
Ok((ByteStr::new(k), decoded))
|
||||
} else {
|
||||
Err(InvalidUtf8InPathParam {
|
||||
key: ByteStr::new(k),
|
||||
})
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>();
|
||||
|
||||
if let Some(current) = req.extensions_mut().get_mut::<Option<UrlParams>>() {
|
||||
let mut current = current.take().unwrap();
|
||||
current.0.extend(params);
|
||||
req.extensions_mut().insert(Some(current));
|
||||
match params {
|
||||
Ok(params) => {
|
||||
let mut current = current.take().unwrap();
|
||||
if let Ok(current) = &mut current.0 {
|
||||
current.extend(params);
|
||||
}
|
||||
req.extensions_mut().insert(Some(current));
|
||||
}
|
||||
Err(err) => {
|
||||
req.extensions_mut().insert(Some(UrlParams(Err(err))));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
req.extensions_mut()
|
||||
.insert(Some(UrlParams(params.collect())));
|
||||
req.extensions_mut().insert(Some(UrlParams(params)));
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct InvalidUtf8InPathParam {
|
||||
pub(crate) key: ByteStr,
|
||||
}
|
||||
|
||||
/// A [`Service`] that responds with `404 Not Found` or `405 Method not allowed`
|
||||
/// to all requests.
|
||||
///
|
||||
|
|
28
src/util.rs
28
src/util.rs
|
@ -30,6 +30,34 @@ impl ByteStr {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub(crate) struct PercentDecodedByteStr(ByteStr);
|
||||
|
||||
impl PercentDecodedByteStr {
|
||||
pub(crate) fn new<S>(s: S) -> Option<Self>
|
||||
where
|
||||
S: AsRef<str>,
|
||||
{
|
||||
percent_encoding::percent_decode(s.as_ref().as_bytes())
|
||||
.decode_utf8()
|
||||
.ok()
|
||||
.map(|decoded| Self(ByteStr::new(decoded)))
|
||||
}
|
||||
|
||||
pub(crate) fn as_str(&self) -> &str {
|
||||
self.0.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for PercentDecodedByteStr {
|
||||
type Target = str;
|
||||
|
||||
#[inline]
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
#[project = EitherProj]
|
||||
pub(crate) enum Either<A, B> {
|
||||
|
|
Loading…
Reference in a new issue