mirror of
https://github.com/tokio-rs/axum.git
synced 2024-12-11 17:30:44 +01:00
feat: percent-decode incoming path before routing
This commit is contained in:
parent
c18cb846d7
commit
0feb657818
9 changed files with 310 additions and 91 deletions
|
@ -391,4 +391,25 @@ mod tests {
|
|||
let res = client.get("/foo").await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn matching_braces() {
|
||||
let app = Router::new().route(
|
||||
// Double braces are interpreted by matchit as single literal brace
|
||||
"/{{foo}}",
|
||||
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/{foo}").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await, "/{{foo}}");
|
||||
|
||||
let res = client.get("/foo").await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
|
||||
let res = client.get("/{{foo}}").await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -262,4 +262,22 @@ mod tests {
|
|||
let res = client.get("/api/users").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn nesting_with_braces() {
|
||||
let api = Router::new().route(
|
||||
"/users",
|
||||
get(|nested_path: NestedPath| {
|
||||
assert_eq!(nested_path.as_str(), "/{{api}}");
|
||||
async {}
|
||||
}),
|
||||
);
|
||||
|
||||
let app = Router::new().nest("/{{api}}", api);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/{api}/users").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use super::{ErrorKind, PathDeserializationError};
|
||||
use crate::util::PercentDecodedStr;
|
||||
use serde::{
|
||||
de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor},
|
||||
forward_to_deserialize_any, Deserializer,
|
||||
|
@ -33,7 +32,7 @@ macro_rules! parse_single_value {
|
|||
|
||||
let value = self.url_params[0].1.parse().map_err(|_| {
|
||||
PathDeserializationError::new(ErrorKind::ParseError {
|
||||
value: self.url_params[0].1.as_str().to_owned(),
|
||||
value: self.url_params[0].1.as_ref().to_owned(),
|
||||
expected_type: $ty,
|
||||
})
|
||||
})?;
|
||||
|
@ -43,12 +42,12 @@ macro_rules! parse_single_value {
|
|||
}
|
||||
|
||||
pub(crate) struct PathDeserializer<'de> {
|
||||
url_params: &'de [(Arc<str>, PercentDecodedStr)],
|
||||
url_params: &'de [(Arc<str>, Arc<str>)],
|
||||
}
|
||||
|
||||
impl<'de> PathDeserializer<'de> {
|
||||
#[inline]
|
||||
pub(crate) fn new(url_params: &'de [(Arc<str>, PercentDecodedStr)]) -> Self {
|
||||
pub(crate) fn new(url_params: &'de [(Arc<str>, Arc<str>)]) -> Self {
|
||||
PathDeserializer { url_params }
|
||||
}
|
||||
}
|
||||
|
@ -216,9 +215,9 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
|
|||
}
|
||||
|
||||
struct MapDeserializer<'de> {
|
||||
params: &'de [(Arc<str>, PercentDecodedStr)],
|
||||
params: &'de [(Arc<str>, Arc<str>)],
|
||||
key: Option<KeyOrIdx<'de>>,
|
||||
value: Option<&'de PercentDecodedStr>,
|
||||
value: Option<&'de Arc<str>>,
|
||||
}
|
||||
|
||||
impl<'de> MapAccess<'de> for MapDeserializer<'de> {
|
||||
|
@ -300,19 +299,19 @@ macro_rules! parse_value {
|
|||
let kind = match key {
|
||||
KeyOrIdx::Key(key) => ErrorKind::ParseErrorAtKey {
|
||||
key: key.to_owned(),
|
||||
value: self.value.as_str().to_owned(),
|
||||
value: self.value.as_ref().to_owned(),
|
||||
expected_type: $ty,
|
||||
},
|
||||
KeyOrIdx::Idx { idx: index, key: _ } => ErrorKind::ParseErrorAtIndex {
|
||||
index,
|
||||
value: self.value.as_str().to_owned(),
|
||||
value: self.value.as_ref().to_owned(),
|
||||
expected_type: $ty,
|
||||
},
|
||||
};
|
||||
PathDeserializationError::new(kind)
|
||||
} else {
|
||||
PathDeserializationError::new(ErrorKind::ParseError {
|
||||
value: self.value.as_str().to_owned(),
|
||||
value: self.value.as_ref().to_owned(),
|
||||
expected_type: $ty,
|
||||
})
|
||||
}
|
||||
|
@ -325,7 +324,7 @@ macro_rules! parse_value {
|
|||
#[derive(Debug)]
|
||||
struct ValueDeserializer<'de> {
|
||||
key: Option<KeyOrIdx<'de>>,
|
||||
value: &'de PercentDecodedStr,
|
||||
value: &'de Arc<str>,
|
||||
}
|
||||
|
||||
impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
|
||||
|
@ -414,7 +413,7 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
|
|||
{
|
||||
struct PairDeserializer<'de> {
|
||||
key: Option<KeyOrIdx<'de>>,
|
||||
value: Option<&'de PercentDecodedStr>,
|
||||
value: Option<&'de Arc<str>>,
|
||||
}
|
||||
|
||||
impl<'de> SeqAccess<'de> for PairDeserializer<'de> {
|
||||
|
@ -576,7 +575,7 @@ impl<'de> VariantAccess<'de> for UnitVariant {
|
|||
}
|
||||
|
||||
struct SeqDeserializer<'de> {
|
||||
params: &'de [(Arc<str>, PercentDecodedStr)],
|
||||
params: &'de [(Arc<str>, Arc<str>)],
|
||||
idx: usize,
|
||||
}
|
||||
|
||||
|
@ -629,7 +628,7 @@ mod tests {
|
|||
a: i32,
|
||||
}
|
||||
|
||||
fn create_url_params<I, K, V>(values: I) -> Vec<(Arc<str>, PercentDecodedStr)>
|
||||
fn create_url_params<I, K, V>(values: I) -> Vec<(Arc<str>, Arc<str>)>
|
||||
where
|
||||
I: IntoIterator<Item = (K, V)>,
|
||||
K: AsRef<str>,
|
||||
|
@ -637,7 +636,7 @@ mod tests {
|
|||
{
|
||||
values
|
||||
.into_iter()
|
||||
.map(|(k, v)| (Arc::from(k.as_ref()), PercentDecodedStr::new(v).unwrap()))
|
||||
.map(|(k, v)| (Arc::from(k.as_ref()), Arc::from(v.as_ref())))
|
||||
.collect()
|
||||
}
|
||||
|
||||
|
@ -669,9 +668,10 @@ mod tests {
|
|||
check_single_value!(f32, "123", 123.0);
|
||||
check_single_value!(f64, "123", 123.0);
|
||||
check_single_value!(String, "abc", "abc");
|
||||
check_single_value!(String, "one%20two", "one two");
|
||||
check_single_value!(String, "one%20two", "one%20two");
|
||||
check_single_value!(String, "one two", "one two");
|
||||
check_single_value!(&str, "abc", "abc");
|
||||
check_single_value!(&str, "one%20two", "one two");
|
||||
check_single_value!(&str, "one two", "one two");
|
||||
check_single_value!(char, "a", 'a');
|
||||
|
||||
let url_params = create_url_params(vec![("a", "B")]);
|
||||
|
|
|
@ -6,7 +6,6 @@ mod de;
|
|||
use crate::{
|
||||
extract::{rejection::*, FromRequestParts},
|
||||
routing::url_params::UrlParams,
|
||||
util::PercentDecodedStr,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use axum_core::response::{IntoResponse, Response};
|
||||
|
@ -156,15 +155,6 @@ where
|
|||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
let params = match parts.extensions.get::<UrlParams>() {
|
||||
Some(UrlParams::Params(params)) => params,
|
||||
Some(UrlParams::InvalidUtf8InPathParam { key }) => {
|
||||
let err = PathDeserializationError {
|
||||
kind: ErrorKind::InvalidUtf8InPathParam {
|
||||
key: key.to_string(),
|
||||
},
|
||||
};
|
||||
let err = FailedToDeserializePathParams(err);
|
||||
return Err(err.into());
|
||||
}
|
||||
None => {
|
||||
return Err(MissingPathParams.into());
|
||||
}
|
||||
|
@ -444,7 +434,7 @@ impl std::error::Error for FailedToDeserializePathParams {}
|
|||
/// # let _: Router = app;
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub struct RawPathParams(Vec<(Arc<str>, PercentDecodedStr)>);
|
||||
pub struct RawPathParams(Vec<(Arc<str>, Arc<str>)>);
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for RawPathParams
|
||||
|
@ -456,12 +446,6 @@ where
|
|||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
let params = match parts.extensions.get::<UrlParams>() {
|
||||
Some(UrlParams::Params(params)) => params,
|
||||
Some(UrlParams::InvalidUtf8InPathParam { key }) => {
|
||||
return Err(InvalidUtf8InPathParam {
|
||||
key: Arc::clone(key),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
None => {
|
||||
return Err(MissingPathParams.into());
|
||||
}
|
||||
|
@ -491,14 +475,14 @@ impl<'a> IntoIterator for &'a RawPathParams {
|
|||
///
|
||||
/// Created with [`RawPathParams::iter`].
|
||||
#[derive(Debug)]
|
||||
pub struct RawPathParamsIter<'a>(std::slice::Iter<'a, (Arc<str>, PercentDecodedStr)>);
|
||||
pub struct RawPathParamsIter<'a>(std::slice::Iter<'a, (Arc<str>, Arc<str>)>);
|
||||
|
||||
impl<'a> Iterator for RawPathParamsIter<'a> {
|
||||
type Item = (&'a str, &'a str);
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let (key, value) = self.0.next()?;
|
||||
Some((&**key, value.as_str()))
|
||||
Some((&**key, &**value))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -890,4 +874,61 @@ mod tests {
|
|||
let body = res.text().await;
|
||||
assert_eq!(body, "a=foo b=bar c=baz");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn percent_encoding_path() {
|
||||
let app = Router::new().route(
|
||||
"/{capture}",
|
||||
get(|Path(path): Path<String>| async move { path }),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/%61pi").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.text().await;
|
||||
assert_eq!(body, "api");
|
||||
|
||||
let res = client.get("/%2561pi").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.text().await;
|
||||
assert_eq!(body, "%61pi");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn percent_encoding_slash_in_path() {
|
||||
let app = Router::new().route(
|
||||
"/{capture}",
|
||||
get(|Path(path): Path<String>| async move { path })
|
||||
.fallback(|| async { panic!("not matched") }),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
// `%2f` decodes to `/`
|
||||
// Slashes are treated specially in the router
|
||||
let res = client.get("/%2flash").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.text().await;
|
||||
assert_eq!(body, "/lash");
|
||||
|
||||
let res = client.get("/%2Flash").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.text().await;
|
||||
assert_eq!(body, "/lash");
|
||||
|
||||
// TODO FIXME
|
||||
// This is not the correct behavior but should be so exceedingly rare that we can live with this for now.
|
||||
let res = client.get("/%252flash").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.text().await;
|
||||
// Should be
|
||||
// assert_eq!(body, "%2flash");
|
||||
assert_eq!(body, "/lash");
|
||||
|
||||
let res = client.get("/%25252flash").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.text().await;
|
||||
assert_eq!(body, "%252flash");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -363,9 +363,21 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
let path = req.uri().path().to_owned();
|
||||
// Double encode any percent-encoded `/`s so that they're not
|
||||
// interpreted by matchit. Additionally, percent-encode `%`s so that we
|
||||
// can differentiate between `%2f` we have encoded to `%252f` and
|
||||
// `%252f` the user might have sent us.
|
||||
let path = req
|
||||
.uri()
|
||||
.path()
|
||||
.replace("%2f", "%252f")
|
||||
.replace("%2F", "%252F");
|
||||
let decode = percent_encoding::percent_decode_str(&path);
|
||||
|
||||
match self.node.at(&path) {
|
||||
match self.node.at(&decode
|
||||
.decode_utf8()
|
||||
.unwrap_or(Cow::Owned(req.uri().path().to_owned())))
|
||||
{
|
||||
Ok(match_) => {
|
||||
let id = *match_.value;
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use http::{Request, Uri};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
@ -60,13 +61,13 @@ fn strip_prefix(uri: &Uri, prefix: &str) -> Option<Uri> {
|
|||
// path = /api/v0/users
|
||||
// ^^^^^^^ this much is matched and the length is 7.
|
||||
let mut matching_prefix_length = Some(0);
|
||||
for item in zip_longest(segments(path_and_query.path()), segments(prefix)) {
|
||||
for item in zip_longest(segments(path_and_query.path()), unescaped_segments(prefix)) {
|
||||
// count the `/`
|
||||
*matching_prefix_length.as_mut().unwrap() += 1;
|
||||
|
||||
match item {
|
||||
Item::Both(path_segment, prefix_segment) => {
|
||||
if is_capture(prefix_segment) || path_segment == prefix_segment {
|
||||
if is_capture(&prefix_segment) || path_segment == prefix_segment {
|
||||
// the prefix segment is either a param, which matches anything, or
|
||||
// it actually matches the path segment
|
||||
*matching_prefix_length.as_mut().unwrap() += path_segment.len();
|
||||
|
@ -121,7 +122,7 @@ fn strip_prefix(uri: &Uri, prefix: &str) -> Option<Uri> {
|
|||
Some(Uri::from_parts(parts).unwrap())
|
||||
}
|
||||
|
||||
fn segments(s: &str) -> impl Iterator<Item = &str> {
|
||||
fn segments(s: &str) -> impl Iterator<Item = Cow<'_, str>> {
|
||||
assert!(
|
||||
s.starts_with('/'),
|
||||
"path didn't start with '/'. axum should have caught this higher up."
|
||||
|
@ -131,6 +132,19 @@ fn segments(s: &str) -> impl Iterator<Item = &str> {
|
|||
// skip one because paths always start with `/` so `/a/b` would become ["", "a", "b"]
|
||||
// otherwise
|
||||
.skip(1)
|
||||
.map(Cow::Borrowed)
|
||||
}
|
||||
|
||||
/// This unescapes anything handled specially by `matchit`.
|
||||
/// Currently, that means only `{{` and `}}` to mean literal `{` and `}` respectively.
|
||||
fn unescaped_segments(s: &str) -> impl Iterator<Item = Cow<'_, str>> {
|
||||
segments(s).map(|segment| {
|
||||
if segment.contains("{{") || segment.contains("}}") {
|
||||
Cow::Owned(segment.replace("{{", "{").replace("}}", "}"))
|
||||
} else {
|
||||
segment
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn zip_longest<I, I2>(a: I, b: I2) -> impl Iterator<Item = Item<I::Item>>
|
||||
|
@ -380,6 +394,48 @@ mod tests {
|
|||
expected = Some("/a"),
|
||||
);
|
||||
|
||||
test!(
|
||||
braces_1,
|
||||
uri = "/{a}/a",
|
||||
prefix = "/{{a}}/",
|
||||
expected = Some("/a"),
|
||||
);
|
||||
|
||||
test!(
|
||||
braces_2,
|
||||
uri = "/{a}/b",
|
||||
prefix = "/{param}",
|
||||
expected = Some("/b"),
|
||||
);
|
||||
|
||||
test!(
|
||||
braces_3,
|
||||
uri = "/{a}/{b}",
|
||||
prefix = "/{{a}}/{{b}}",
|
||||
expected = Some("/"),
|
||||
);
|
||||
|
||||
test!(
|
||||
braces_4,
|
||||
uri = "/{a}/{b}",
|
||||
prefix = "/{{a}}/{b}",
|
||||
expected = Some("/"),
|
||||
);
|
||||
|
||||
test!(
|
||||
braces_5,
|
||||
uri = "/a/{b}",
|
||||
prefix = "/a",
|
||||
expected = Some("/{b}"),
|
||||
);
|
||||
|
||||
test!(
|
||||
braces_6,
|
||||
uri = "/a/{b}",
|
||||
prefix = "/{a}/{{b}}",
|
||||
expected = Some("/"),
|
||||
);
|
||||
|
||||
#[quickcheck]
|
||||
fn does_not_panic(uri_and_prefix: UriAndPrefix) -> bool {
|
||||
let UriAndPrefix { uri, prefix } = uri_and_prefix;
|
||||
|
|
|
@ -1118,3 +1118,117 @@ async fn colon_in_route() {
|
|||
async fn asterisk_in_route() {
|
||||
_ = Router::<()>::new().route("/*foo", get(|| async move {}));
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn colon_in_route_allowed() {
|
||||
let app = Router::<()>::new()
|
||||
.without_v07_checks()
|
||||
.route("/:foo", get(|| async move {}));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/:foo").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let res = client.get("/foo").await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn asterisk_in_route_allowed() {
|
||||
let app = Router::<()>::new()
|
||||
.without_v07_checks()
|
||||
.route("/*foo", get(|| async move {}));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/*foo").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let res = client.get("/foo").await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn percent_encoding() {
|
||||
let app = Router::new().route("/api", get(|| async { "api" }));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/%61pi").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.text().await;
|
||||
assert_eq!(body, "api");
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn percent_encoding_slash() {
|
||||
let app = Router::new()
|
||||
.route("/slash/%2flash", get(|| async { "lower" }))
|
||||
.route("/slash/%2Flash", get(|| async { "upper" }))
|
||||
.route("/slash//lash", get(|| async { "/" }))
|
||||
.route("/api/user", get(|| async { "user" }))
|
||||
.route(
|
||||
"/{capture}",
|
||||
get(|Path(capture): Path<String>| {
|
||||
assert_eq!(capture, "api/user");
|
||||
ready("capture")
|
||||
}),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
// %2f encodes `/`
|
||||
let res = client.get("/api%2fuser").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.text().await;
|
||||
assert_eq!(body, "capture");
|
||||
|
||||
let res = client.get("/slash/%2flash").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.text().await;
|
||||
assert_eq!(body, "lower");
|
||||
|
||||
let res = client.get("/slash/%2Flash").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.text().await;
|
||||
assert_eq!(body, "upper");
|
||||
|
||||
// `%25` encodes `%`
|
||||
// This must not be decoded twice
|
||||
let res = client.get("/slash/%252flash").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.text().await;
|
||||
assert_eq!(body, "lower");
|
||||
|
||||
let res = client.get("/slash/%252Flash").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.text().await;
|
||||
assert_eq!(body, "upper");
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn percent_encoding_percent() {
|
||||
let app = Router::new()
|
||||
.route("/%61pi", get(|| async { "percent" }))
|
||||
.route("/api", get(|| async { "api" }));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/api").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.text().await;
|
||||
assert_eq!(body, "api");
|
||||
|
||||
let res = client.get("/%61pi").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.text().await;
|
||||
assert_eq!(body, "api");
|
||||
|
||||
// `%25` encodes `%`
|
||||
// This must not be decoded twice, otherwise it will become `/api`
|
||||
let res = client.get("/%2561pi").await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = res.text().await;
|
||||
assert_eq!(body, "percent");
|
||||
}
|
||||
|
|
|
@ -1,46 +1,32 @@
|
|||
use crate::util::PercentDecodedStr;
|
||||
use http::Extensions;
|
||||
use matchit::Params;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) enum UrlParams {
|
||||
Params(Vec<(Arc<str>, PercentDecodedStr)>),
|
||||
InvalidUtf8InPathParam { key: Arc<str> },
|
||||
Params(Vec<(Arc<str>, Arc<str>)>),
|
||||
}
|
||||
|
||||
pub(super) fn insert_url_params(extensions: &mut Extensions, params: Params) {
|
||||
let current_params = extensions.get_mut();
|
||||
|
||||
if let Some(UrlParams::InvalidUtf8InPathParam { .. }) = current_params {
|
||||
// nothing to do here since an error was stored earlier
|
||||
return;
|
||||
}
|
||||
|
||||
let params = params
|
||||
.iter()
|
||||
.filter(|(key, _)| !key.starts_with(super::NEST_TAIL_PARAM))
|
||||
.filter(|(key, _)| !key.starts_with(super::FALLBACK_PARAM))
|
||||
.map(|(k, v)| {
|
||||
if let Some(decoded) = PercentDecodedStr::new(v) {
|
||||
Ok((Arc::from(k), decoded))
|
||||
} else {
|
||||
Err(Arc::from(k))
|
||||
}
|
||||
(
|
||||
Arc::from(k),
|
||||
Arc::from(v.replace("%2f", "/").replace("%2F", "/")),
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>();
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
match (current_params, params) {
|
||||
(Some(UrlParams::InvalidUtf8InPathParam { .. }), _) => {
|
||||
unreachable!("we check for this state earlier in this method")
|
||||
}
|
||||
(_, Err(invalid_key)) => {
|
||||
extensions.insert(UrlParams::InvalidUtf8InPathParam { key: invalid_key });
|
||||
}
|
||||
(Some(UrlParams::Params(current)), Ok(params)) => {
|
||||
(Some(UrlParams::Params(current)), params) => {
|
||||
current.extend(params);
|
||||
}
|
||||
(None, Ok(params)) => {
|
||||
(None, params) => {
|
||||
extensions.insert(UrlParams::Params(params));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,36 +1,7 @@
|
|||
use pin_project_lite::pin_project;
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
|
||||
pub(crate) use self::mutex::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub(crate) struct PercentDecodedStr(Arc<str>);
|
||||
|
||||
impl PercentDecodedStr {
|
||||
pub(crate) fn new<S>(s: S) -> Option<Self>
|
||||
where
|
||||
S: AsRef<str>,
|
||||
{
|
||||
percent_encoding::percent_decode(s.as_ref().as_bytes())
|
||||
.decode_utf8()
|
||||
.ok()
|
||||
.map(|decoded| Self(decoded.as_ref().into()))
|
||||
}
|
||||
|
||||
pub(crate) fn as_str(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for PercentDecodedStr {
|
||||
type Target = str;
|
||||
|
||||
#[inline]
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
#[project = EitherProj]
|
||||
pub(crate) enum Either<A, B> {
|
||||
|
|
Loading…
Reference in a new issue