Percent decode automatically in extract::Path (#272)

* Percent decode automatically in `extract::Path`

Fixes https://github.com/tokio-rs/axum/issues/261

* return an error if path param contains invalid utf-8

* Mention automatic decoding in the docs

* Update changelog: This is a breaking change

* cleanup

* fix tests
This commit is contained in:
David Pedersen 2021-10-02 16:04:29 +02:00 committed by GitHub
parent 2c2bcd7754
commit afabded385
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 156 additions and 58 deletions

View file

@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- Improve performance of `BoxRoute` ([#339])
- **breaking:** Automatically do percent decoding in `extract::Path`
([#272])
- **breaking:** `Router::boxed` now the inner service to implement `Clone` and
`Sync` in addition to the previous trait bounds ([#339])
- **breaking:** Added feature flags for HTTP1 and JSON. This enables removing a
@ -16,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#339]: https://github.com/tokio-rs/axum/pull/339
[#286]: https://github.com/tokio-rs/axum/pull/286
[#272]: https://github.com/tokio-rs/axum/pull/272
# 0.2.6 (02. October, 2021)

View file

@ -27,30 +27,32 @@ ws = ["tokio-tungstenite", "sha-1", "base64"]
async-trait = "0.1"
bitflags = "1.0"
bytes = "1.0"
dyn-clone = "1.0"
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
http = "0.2"
http-body = "0.4.3"
hyper = { version = "0.14", features = ["server", "tcp", "stream"] }
percent-encoding = "2.1"
pin-project-lite = "0.2.7"
regex = "1.5"
serde = "1.0"
serde_json = { version = "1.0", optional = true }
serde_urlencoded = "0.7"
sync_wrapper = "0.1.1"
tokio = { version = "1", features = ["time"] }
tokio-util = "0.6"
tower = { version = "0.4", default-features = false, features = ["util", "buffer", "make"] }
tower-service = "0.3"
tower-layer = "0.3"
tower-http = { version = "0.1", features = ["add-extension", "map-response-body"] }
sync_wrapper = "0.1.1"
tower-layer = "0.3"
tower-service = "0.3"
# optional dependencies
tokio-tungstenite = { optional = true, version = "0.15" }
sha-1 = { optional = true, version = "0.9.6" }
base64 = { optional = true, version = "0.13" }
headers = { optional = true, version = "0.3" }
multer = { optional = true, version = "2.0.0" }
mime = { optional = true, version = "0.3" }
multer = { optional = true, version = "2.0.0" }
serde_json = { version = "1.0", optional = true }
sha-1 = { optional = true, version = "0.9.6" }
tokio-tungstenite = { optional = true, version = "0.15" }
[dev-dependencies]
futures = "0.3"
@ -82,4 +84,11 @@ all-features = true
rustdoc-args = ["--cfg", "docsrs"]
[package.metadata.playground]
features = ["ws", "multipart", "headers"]
features = [
"http1",
"http2",
"json",
"multipart",
"tower",
"ws",
]

View file

@ -1,5 +1,4 @@
use crate::routing::UrlParams;
use crate::util::ByteStr;
use crate::util::{ByteStr, PercentDecodedByteStr};
use serde::{
de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor},
forward_to_deserialize_any, Deserializer,
@ -53,20 +52,20 @@ macro_rules! parse_single_value {
where
V: Visitor<'de>,
{
if self.url_params.0.len() != 1 {
if self.url_params.len() != 1 {
return Err(PathDeserializerError::custom(
format!(
"wrong number of parameters: {} expected 1",
self.url_params.0.len()
self.url_params.len()
)
.as_str(),
));
}
let value = self.url_params.0[0].1.parse().map_err(|_| {
let value = self.url_params[0].1.parse().map_err(|_| {
PathDeserializerError::custom(format!(
"can not parse `{:?}` to a `{}`",
self.url_params.0[0].1.as_str(),
self.url_params[0].1.as_str(),
$tp
))
})?;
@ -76,12 +75,12 @@ macro_rules! parse_single_value {
}
pub(crate) struct PathDeserializer<'de> {
url_params: &'de UrlParams,
url_params: &'de [(ByteStr, PercentDecodedByteStr)],
}
impl<'de> PathDeserializer<'de> {
#[inline]
pub(crate) fn new(url_params: &'de UrlParams) -> Self {
pub(crate) fn new(url_params: &'de [(ByteStr, PercentDecodedByteStr)]) -> Self {
PathDeserializer { url_params }
}
}
@ -114,13 +113,13 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
where
V: Visitor<'de>,
{
if self.url_params.0.len() != 1 {
if self.url_params.len() != 1 {
return Err(PathDeserializerError::custom(format!(
"wrong number of parameters: {} expected 1",
self.url_params.0.len()
self.url_params.len()
)));
}
visitor.visit_str(&self.url_params.0[0].1)
visitor.visit_str(&self.url_params[0].1)
}
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
@ -157,7 +156,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
V: Visitor<'de>,
{
visitor.visit_seq(SeqDeserializer {
params: &self.url_params.0,
params: self.url_params,
})
}
@ -165,18 +164,18 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
where
V: Visitor<'de>,
{
if self.url_params.0.len() < len {
if self.url_params.len() < len {
return Err(PathDeserializerError::custom(
format!(
"wrong number of parameters: {} expected {}",
self.url_params.0.len(),
self.url_params.len(),
len
)
.as_str(),
));
}
visitor.visit_seq(SeqDeserializer {
params: &self.url_params.0,
params: self.url_params,
})
}
@ -189,18 +188,18 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
where
V: Visitor<'de>,
{
if self.url_params.0.len() < len {
if self.url_params.len() < len {
return Err(PathDeserializerError::custom(
format!(
"wrong number of parameters: {} expected {}",
self.url_params.0.len(),
self.url_params.len(),
len
)
.as_str(),
));
}
visitor.visit_seq(SeqDeserializer {
params: &self.url_params.0,
params: self.url_params,
})
}
@ -209,7 +208,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
V: Visitor<'de>,
{
visitor.visit_map(MapDeserializer {
params: &self.url_params.0,
params: self.url_params,
value: None,
})
}
@ -235,21 +234,21 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
where
V: Visitor<'de>,
{
if self.url_params.0.len() != 1 {
if self.url_params.len() != 1 {
return Err(PathDeserializerError::custom(format!(
"wrong number of parameters: {} expected 1",
self.url_params.0.len()
self.url_params.len()
)));
}
visitor.visit_enum(EnumDeserializer {
value: &self.url_params.0[0].1,
value: &self.url_params[0].1,
})
}
}
struct MapDeserializer<'de> {
params: &'de [(ByteStr, ByteStr)],
params: &'de [(ByteStr, PercentDecodedByteStr)],
value: Option<&'de str>,
}
@ -519,7 +518,7 @@ impl<'de> VariantAccess<'de> for UnitVariant {
}
struct SeqDeserializer<'de> {
params: &'de [(ByteStr, ByteStr)],
params: &'de [(ByteStr, PercentDecodedByteStr)],
}
impl<'de> SeqAccess<'de> for SeqDeserializer<'de> {
@ -561,18 +560,16 @@ mod tests {
a: i32,
}
fn create_url_params<I, K, V>(values: I) -> UrlParams
fn create_url_params<I, K, V>(values: I) -> Vec<(ByteStr, PercentDecodedByteStr)>
where
I: IntoIterator<Item = (K, V)>,
K: AsRef<str>,
V: AsRef<str>,
{
UrlParams(
values
.into_iter()
.map(|(k, v)| (ByteStr::new(k), ByteStr::new(v)))
.collect(),
)
values
.into_iter()
.map(|(k, v)| (ByteStr::new(k), PercentDecodedByteStr::new(v).unwrap()))
.collect()
}
macro_rules! check_single_value {
@ -601,6 +598,7 @@ mod tests {
check_single_value!(f32, "123", 123.0);
check_single_value!(f64, "123", 123.0);
check_single_value!(String, "abc", "abc");
check_single_value!(String, "one%20two", "one two");
check_single_value!(char, "a", 'a');
let url_params = create_url_params(vec![("a", "B")]);

View file

@ -1,14 +1,24 @@
mod de;
use super::{rejection::*, FromRequest};
use crate::{extract::RequestParts, routing::UrlParams};
use crate::{
extract::RequestParts,
routing::{InvalidUtf8InPathParam, UrlParams},
};
use async_trait::async_trait;
use serde::de::DeserializeOwned;
use std::ops::{Deref, DerefMut};
use std::{
borrow::Cow,
ops::{Deref, DerefMut},
};
/// Extractor that will get captures from the URL and parse them using
/// [`serde`].
///
/// Any percent encoded parameters will be automatically decoded. The decoded
/// parameters must be valid UTF-8, otherwise `Path` will fail and return a `400
/// Bad Request` response.
///
/// # Example
///
/// ```rust,no_run
@ -140,20 +150,45 @@ where
{
type Rejection = PathParamsRejection;
#[allow(warnings)]
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
const EMPTY_URL_PARAMS: &UrlParams = &UrlParams(Vec::new());
let url_params = if let Some(params) = req
let params = match req
.extensions_mut()
.and_then(|ext| ext.get::<Option<UrlParams>>())
{
params.as_ref().unwrap_or(EMPTY_URL_PARAMS)
} else {
return Err(MissingRouteParams.into());
Some(Some(UrlParams(Ok(params)))) => Cow::Borrowed(params),
Some(Some(UrlParams(Err(InvalidUtf8InPathParam { key })))) => {
return Err(InvalidPathParam::new(key.as_str()).into())
}
Some(None) => Cow::Owned(Vec::new()),
None => {
return Err(MissingRouteParams.into());
}
};
T::deserialize(de::PathDeserializer::new(url_params))
T::deserialize(de::PathDeserializer::new(&*params))
.map_err(|err| PathParamsRejection::InvalidPathParam(InvalidPathParam::new(err.0)))
.map(Path)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::*;
use crate::{handler::get, Router};
#[tokio::test]
async fn percent_decoding() {
let app = Router::new().route(
"/:key",
get(|Path(param): Path<String>| async move { param }),
);
let client = TestClient::new(app);
let res = client.get("/one%20two").send().await;
assert_eq!(res.text().await, "one two");
}
}

View file

@ -107,7 +107,7 @@ define_rejection! {
/// Rejection type for [`Path`](super::Path) if the capture route
/// param didn't have the expected type.
#[derive(Debug)]
pub struct InvalidPathParam(String);
pub struct InvalidPathParam(pub(crate) String);
impl InvalidPathParam {
pub(super) fn new(err: impl Into<String>) -> Self {

View file

@ -9,7 +9,7 @@ use crate::{
OriginalUri,
},
service::HandleError,
util::ByteStr,
util::{ByteStr, PercentDecodedByteStr},
BoxError,
};
use bytes::Bytes;
@ -627,24 +627,49 @@ where
}
}
#[derive(Debug)]
pub(crate) struct UrlParams(pub(crate) Vec<(ByteStr, ByteStr)>);
// we store the potential error here such that users can handle invalid path
// params using `Result<Path<T>, _>`. That wouldn't be possible if we
// returned an error immediately when decoding the param
pub(crate) struct UrlParams(
pub(crate) Result<Vec<(ByteStr, PercentDecodedByteStr)>, InvalidUtf8InPathParam>,
);
fn insert_url_params<B>(req: &mut Request<B>, params: Vec<(String, String)>) {
let params = params
.into_iter()
.map(|(k, v)| (ByteStr::new(k), ByteStr::new(v)));
.map(|(k, v)| {
if let Some(decoded) = PercentDecodedByteStr::new(v) {
Ok((ByteStr::new(k), decoded))
} else {
Err(InvalidUtf8InPathParam {
key: ByteStr::new(k),
})
}
})
.collect::<Result<Vec<_>, _>>();
if let Some(current) = req.extensions_mut().get_mut::<Option<UrlParams>>() {
let mut current = current.take().unwrap();
current.0.extend(params);
req.extensions_mut().insert(Some(current));
match params {
Ok(params) => {
let mut current = current.take().unwrap();
if let Ok(current) = &mut current.0 {
current.extend(params);
}
req.extensions_mut().insert(Some(current));
}
Err(err) => {
req.extensions_mut().insert(Some(UrlParams(Err(err))));
}
}
} else {
req.extensions_mut()
.insert(Some(UrlParams(params.collect())));
req.extensions_mut().insert(Some(UrlParams(params)));
}
}
pub(crate) struct InvalidUtf8InPathParam {
pub(crate) key: ByteStr,
}
/// A [`Service`] that responds with `404 Not Found` or `405 Method not allowed`
/// to all requests.
///

View file

@ -30,6 +30,34 @@ impl ByteStr {
}
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct PercentDecodedByteStr(ByteStr);
impl PercentDecodedByteStr {
pub(crate) fn new<S>(s: S) -> Option<Self>
where
S: AsRef<str>,
{
percent_encoding::percent_decode(s.as_ref().as_bytes())
.decode_utf8()
.ok()
.map(|decoded| Self(ByteStr::new(decoded)))
}
pub(crate) fn as_str(&self) -> &str {
self.0.as_str()
}
}
impl Deref for PercentDecodedByteStr {
type Target = str;
#[inline]
fn deref(&self) -> &Self::Target {
self.as_str()
}
}
pin_project! {
#[project = EitherProj]
pub(crate) enum Either<A, B> {