feat: percent-decode incoming path before routing

This commit is contained in:
David Mládek 2024-05-02 22:52:38 +02:00
parent c18cb846d7
commit 0feb657818
9 changed files with 310 additions and 91 deletions

View file

@ -391,4 +391,25 @@ mod tests {
let res = client.get("/foo").await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
}
// Doubled braces in a route pattern stand for literal braces, so the route
// below is a plain literal path and not a capture.
#[crate::test]
async fn matching_braces() {
    let echo_matched_path = |matched: MatchedPath| async move { matched.as_str().to_owned() };
    // Double braces are interpreted by matchit as single literal brace
    let router = Router::new().route("/{{foo}}", get(echo_matched_path));
    let client = TestClient::new(router);

    // The request uses single braces; the matched path is reported in the
    // escaped form the route was registered with.
    let literal = client.get("/{foo}").await;
    assert_eq!(literal.status(), StatusCode::OK);
    assert_eq!(literal.text().await, "/{{foo}}");

    // Neither the bare name nor the still-escaped spelling matches.
    for unmatched in ["/foo", "/{{foo}}"] {
        let res = client.get(unmatched).await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
    }
}
}

View file

@ -262,4 +262,22 @@ mod tests {
let res = client.get("/api/users").await;
assert_eq!(res.status(), StatusCode::OK);
}
// Nesting under a prefix written with escaped braces: the request carries
// single braces, while `NestedPath` reports the prefix as given to `nest`.
#[crate::test]
async fn nesting_with_braces() {
    let users = get(|nested_path: NestedPath| {
        // Checked synchronously, before the (empty) response future is built.
        assert_eq!(nested_path.as_str(), "/{{api}}");
        async {}
    });
    let inner = Router::new().route("/users", users);
    let outer = Router::new().nest("/{{api}}", inner);
    let client = TestClient::new(outer);

    let response = client.get("/{api}/users").await;
    assert_eq!(response.status(), StatusCode::OK);
}
}

View file

@ -1,5 +1,4 @@
use super::{ErrorKind, PathDeserializationError};
use crate::util::PercentDecodedStr;
use serde::{
de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor},
forward_to_deserialize_any, Deserializer,
@ -33,7 +32,7 @@ macro_rules! parse_single_value {
let value = self.url_params[0].1.parse().map_err(|_| {
PathDeserializationError::new(ErrorKind::ParseError {
value: self.url_params[0].1.as_str().to_owned(),
value: self.url_params[0].1.as_ref().to_owned(),
expected_type: $ty,
})
})?;
@ -43,12 +42,12 @@ macro_rules! parse_single_value {
}
pub(crate) struct PathDeserializer<'de> {
url_params: &'de [(Arc<str>, PercentDecodedStr)],
url_params: &'de [(Arc<str>, Arc<str>)],
}
impl<'de> PathDeserializer<'de> {
#[inline]
pub(crate) fn new(url_params: &'de [(Arc<str>, PercentDecodedStr)]) -> Self {
pub(crate) fn new(url_params: &'de [(Arc<str>, Arc<str>)]) -> Self {
PathDeserializer { url_params }
}
}
@ -216,9 +215,9 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
}
struct MapDeserializer<'de> {
params: &'de [(Arc<str>, PercentDecodedStr)],
params: &'de [(Arc<str>, Arc<str>)],
key: Option<KeyOrIdx<'de>>,
value: Option<&'de PercentDecodedStr>,
value: Option<&'de Arc<str>>,
}
impl<'de> MapAccess<'de> for MapDeserializer<'de> {
@ -300,19 +299,19 @@ macro_rules! parse_value {
let kind = match key {
KeyOrIdx::Key(key) => ErrorKind::ParseErrorAtKey {
key: key.to_owned(),
value: self.value.as_str().to_owned(),
value: self.value.as_ref().to_owned(),
expected_type: $ty,
},
KeyOrIdx::Idx { idx: index, key: _ } => ErrorKind::ParseErrorAtIndex {
index,
value: self.value.as_str().to_owned(),
value: self.value.as_ref().to_owned(),
expected_type: $ty,
},
};
PathDeserializationError::new(kind)
} else {
PathDeserializationError::new(ErrorKind::ParseError {
value: self.value.as_str().to_owned(),
value: self.value.as_ref().to_owned(),
expected_type: $ty,
})
}
@ -325,7 +324,7 @@ macro_rules! parse_value {
#[derive(Debug)]
struct ValueDeserializer<'de> {
key: Option<KeyOrIdx<'de>>,
value: &'de PercentDecodedStr,
value: &'de Arc<str>,
}
impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
@ -414,7 +413,7 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
{
struct PairDeserializer<'de> {
key: Option<KeyOrIdx<'de>>,
value: Option<&'de PercentDecodedStr>,
value: Option<&'de Arc<str>>,
}
impl<'de> SeqAccess<'de> for PairDeserializer<'de> {
@ -576,7 +575,7 @@ impl<'de> VariantAccess<'de> for UnitVariant {
}
struct SeqDeserializer<'de> {
params: &'de [(Arc<str>, PercentDecodedStr)],
params: &'de [(Arc<str>, Arc<str>)],
idx: usize,
}
@ -629,7 +628,7 @@ mod tests {
a: i32,
}
fn create_url_params<I, K, V>(values: I) -> Vec<(Arc<str>, PercentDecodedStr)>
fn create_url_params<I, K, V>(values: I) -> Vec<(Arc<str>, Arc<str>)>
where
I: IntoIterator<Item = (K, V)>,
K: AsRef<str>,
@ -637,7 +636,7 @@ mod tests {
{
values
.into_iter()
.map(|(k, v)| (Arc::from(k.as_ref()), PercentDecodedStr::new(v).unwrap()))
.map(|(k, v)| (Arc::from(k.as_ref()), Arc::from(v.as_ref())))
.collect()
}
@ -669,9 +668,10 @@ mod tests {
check_single_value!(f32, "123", 123.0);
check_single_value!(f64, "123", 123.0);
check_single_value!(String, "abc", "abc");
check_single_value!(String, "one%20two", "one two");
check_single_value!(String, "one%20two", "one%20two");
check_single_value!(String, "one two", "one two");
check_single_value!(&str, "abc", "abc");
check_single_value!(&str, "one%20two", "one two");
check_single_value!(&str, "one two", "one two");
check_single_value!(char, "a", 'a');
let url_params = create_url_params(vec![("a", "B")]);

View file

@ -6,7 +6,6 @@ mod de;
use crate::{
extract::{rejection::*, FromRequestParts},
routing::url_params::UrlParams,
util::PercentDecodedStr,
};
use async_trait::async_trait;
use axum_core::response::{IntoResponse, Response};
@ -156,15 +155,6 @@ where
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let params = match parts.extensions.get::<UrlParams>() {
Some(UrlParams::Params(params)) => params,
Some(UrlParams::InvalidUtf8InPathParam { key }) => {
let err = PathDeserializationError {
kind: ErrorKind::InvalidUtf8InPathParam {
key: key.to_string(),
},
};
let err = FailedToDeserializePathParams(err);
return Err(err.into());
}
None => {
return Err(MissingPathParams.into());
}
@ -444,7 +434,7 @@ impl std::error::Error for FailedToDeserializePathParams {}
/// # let _: Router = app;
/// ```
#[derive(Debug)]
pub struct RawPathParams(Vec<(Arc<str>, PercentDecodedStr)>);
pub struct RawPathParams(Vec<(Arc<str>, Arc<str>)>);
#[async_trait]
impl<S> FromRequestParts<S> for RawPathParams
@ -456,12 +446,6 @@ where
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let params = match parts.extensions.get::<UrlParams>() {
Some(UrlParams::Params(params)) => params,
Some(UrlParams::InvalidUtf8InPathParam { key }) => {
return Err(InvalidUtf8InPathParam {
key: Arc::clone(key),
}
.into());
}
None => {
return Err(MissingPathParams.into());
}
@ -491,14 +475,14 @@ impl<'a> IntoIterator for &'a RawPathParams {
///
/// Created with [`RawPathParams::iter`].
#[derive(Debug)]
pub struct RawPathParamsIter<'a>(std::slice::Iter<'a, (Arc<str>, PercentDecodedStr)>);
pub struct RawPathParamsIter<'a>(std::slice::Iter<'a, (Arc<str>, Arc<str>)>);
impl<'a> Iterator for RawPathParamsIter<'a> {
type Item = (&'a str, &'a str);
fn next(&mut self) -> Option<Self::Item> {
let (key, value) = self.0.next()?;
Some((&**key, value.as_str()))
Some((&**key, &**value))
}
}
@ -890,4 +874,61 @@ mod tests {
let body = res.text().await;
assert_eq!(body, "a=foo b=bar c=baz");
}
// A captured path segment is handed to the `Path` extractor percent-decoded
// exactly once.
#[tokio::test]
async fn percent_encoding_path() {
    let echo_capture = get(|Path(captured): Path<String>| async move { captured });
    let router = Router::new().route("/{capture}", echo_capture);
    let client = TestClient::new(router);

    // (request path, expected response body)
    let cases = [
        // `%61` decodes to `a`
        ("/%61pi", "api"),
        // `%25` decodes to `%`; the resulting `%61` is left alone
        ("/%2561pi", "%61pi"),
    ];

    for (uri, expected) in cases {
        let res = client.get(uri).await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.text().await;
        assert_eq!(body, expected);
    }
}
// Percent-encoded slashes inside a single captured segment must stay within
// that segment and reach the extractor as a literal `/`.
#[tokio::test]
async fn percent_encoding_slash_in_path() {
    let echo_capture = get(|Path(captured): Path<String>| async move { captured })
        .fallback(|| async { panic!("not matched") });
    let router = Router::new().route("/{capture}", echo_capture);
    let client = TestClient::new(router);

    // (request path, expected response body)
    let cases = [
        // `%2f` decodes to `/`
        // Slashes are treated specially in the router
        ("/%2flash", "/lash"),
        ("/%2Flash", "/lash"),
        // TODO FIXME
        // This is not the correct behavior but should be so exceedingly rare
        // that we can live with this for now.
        // The body should be "%2flash".
        ("/%252flash", "/lash"),
        ("/%25252flash", "%252flash"),
    ];

    for (uri, expected) in cases {
        let res = client.get(uri).await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.text().await;
        assert_eq!(body, expected);
    }
}
}

View file

@ -363,9 +363,21 @@ where
}
}
let path = req.uri().path().to_owned();
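// Double encode any percent-encoded `/`s (`%2f`/`%2F` become `%252f`/`%252F`)
// so that the percent-decoding below turns them back into `%2f`/`%2F` instead
// of a literal `/` that matchit would interpret.
//
// NOTE: literal `%`s are not escaped here, so a `%252f` sent by the client
// decodes to the same `%2f` as one we re-encoded and the two cannot be told
// apart afterwards.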
let path = req
.uri()
.path()
.replace("%2f", "%252f")
.replace("%2F", "%252F");
let decode = percent_encoding::percent_decode_str(&path);
match self.node.at(&path) {
match self.node.at(&decode
.decode_utf8()
.unwrap_or(Cow::Owned(req.uri().path().to_owned())))
{
Ok(match_) => {
let id = *match_.value;

View file

@ -1,5 +1,6 @@
use http::{Request, Uri};
use std::{
borrow::Cow,
sync::Arc,
task::{Context, Poll},
};
@ -60,13 +61,13 @@ fn strip_prefix(uri: &Uri, prefix: &str) -> Option<Uri> {
// path = /api/v0/users
// ^^^^^^^ this much is matched and the length is 7.
let mut matching_prefix_length = Some(0);
for item in zip_longest(segments(path_and_query.path()), segments(prefix)) {
for item in zip_longest(segments(path_and_query.path()), unescaped_segments(prefix)) {
// count the `/`
*matching_prefix_length.as_mut().unwrap() += 1;
match item {
Item::Both(path_segment, prefix_segment) => {
if is_capture(prefix_segment) || path_segment == prefix_segment {
if is_capture(&prefix_segment) || path_segment == prefix_segment {
// the prefix segment is either a param, which matches anything, or
// it actually matches the path segment
*matching_prefix_length.as_mut().unwrap() += path_segment.len();
@ -121,7 +122,7 @@ fn strip_prefix(uri: &Uri, prefix: &str) -> Option<Uri> {
Some(Uri::from_parts(parts).unwrap())
}
fn segments(s: &str) -> impl Iterator<Item = &str> {
fn segments(s: &str) -> impl Iterator<Item = Cow<'_, str>> {
assert!(
s.starts_with('/'),
"path didn't start with '/'. axum should have caught this higher up."
@ -131,6 +132,19 @@ fn segments(s: &str) -> impl Iterator<Item = &str> {
// skip one because paths always start with `/` so `/a/b` would become ["", "a", "b"]
// otherwise
.skip(1)
.map(Cow::Borrowed)
}
/// Yields the segments of `s` with `matchit`'s escape sequences resolved.
///
/// Currently the only escapes are `{{` and `}}`, which stand for a literal
/// `{` and `}`. Segments containing neither are returned unchanged.
fn unescaped_segments(s: &str) -> impl Iterator<Item = Cow<'_, str>> {
    segments(s).map(|raw| {
        let has_escape = ["{{", "}}"].iter().any(|&escape| raw.contains(escape));
        if !has_escape {
            return raw;
        }

        let unescaped = raw.replace("{{", "{").replace("}}", "}");
        Cow::Owned(unescaped)
    })
}
fn zip_longest<I, I2>(a: I, b: I2) -> impl Iterator<Item = Item<I::Item>>
@ -380,6 +394,48 @@ mod tests {
expected = Some("/a"),
);
// Escaped `{{a}}` in the prefix matches the literal `{a}` segment of the URI.
test!(
braces_1,
uri = "/{a}/a",
prefix = "/{{a}}/",
expected = Some("/a"),
);
// An unescaped `{param}` prefix segment is a capture, so it also matches a
// URI segment that itself contains braces.
test!(
braces_2,
uri = "/{a}/b",
prefix = "/{param}",
expected = Some("/b"),
);
// Two escaped prefix segments match two literal brace segments; nothing but
// the root is left.
test!(
braces_3,
uri = "/{a}/{b}",
prefix = "/{{a}}/{{b}}",
expected = Some("/"),
);
// Escaped literal segment followed by a capture.
test!(
braces_4,
uri = "/{a}/{b}",
prefix = "/{{a}}/{b}",
expected = Some("/"),
);
// Braces in the part of the URI after the prefix are kept as-is.
test!(
braces_5,
uri = "/a/{b}",
prefix = "/a",
expected = Some("/{b}"),
);
// Capture followed by an escaped literal segment.
test!(
braces_6,
uri = "/a/{b}",
prefix = "/{a}/{{b}}",
expected = Some("/"),
);
#[quickcheck]
fn does_not_panic(uri_and_prefix: UriAndPrefix) -> bool {
let UriAndPrefix { uri, prefix } = uri_and_prefix;

View file

@ -1118,3 +1118,117 @@ async fn colon_in_route() {
async fn asterisk_in_route() {
_ = Router::<()>::new().route("/*foo", get(|| async move {}));
}
// After `without_v07_checks()`, `/:foo` is routed as a literal path: only the
// request `/:foo` matches it.
#[crate::test]
async fn colon_in_route_allowed() {
    let router = Router::<()>::new()
        .without_v07_checks()
        .route("/:foo", get(|| async move {}));
    let client = TestClient::new(router);

    // (request path, expected status)
    let expectations = [
        ("/:foo", StatusCode::OK),
        ("/foo", StatusCode::NOT_FOUND),
    ];
    for (uri, expected_status) in expectations {
        let res = client.get(uri).await;
        assert_eq!(res.status(), expected_status);
    }
}
// After `without_v07_checks()`, `/*foo` is routed as a literal path: only the
// request `/*foo` matches it.
#[crate::test]
async fn asterisk_in_route_allowed() {
    let router = Router::<()>::new()
        .without_v07_checks()
        .route("/*foo", get(|| async move {}));
    let client = TestClient::new(router);

    let literal = client.get("/*foo").await;
    assert_eq!(literal.status(), StatusCode::OK);

    let bare = client.get("/foo").await;
    assert_eq!(bare.status(), StatusCode::NOT_FOUND);
}
// A percent-encoded request path is matched against the plain route: `%61`
// decodes to `a`, so `/%61pi` reaches the `/api` handler.
#[crate::test]
async fn percent_encoding() {
    let router = Router::new().route("/api", get(|| async { "api" }));
    let client = TestClient::new(router);

    let response = client.get("/%61pi").await;
    assert_eq!(response.status(), StatusCode::OK);

    let text = response.text().await;
    assert_eq!(text, "api");
}
// Routing of request paths that contain percent-encoded slashes, both when a
// route spells out `%2f`/`%2F` itself and when a capture takes the segment.
#[crate::test]
async fn percent_encoding_slash() {
    // The capture route must only ever see the decoded `api/user` segment.
    let capture_route = get(|Path(capture): Path<String>| {
        assert_eq!(capture, "api/user");
        ready("capture")
    });
    let router = Router::new()
        .route("/slash/%2flash", get(|| async { "lower" }))
        .route("/slash/%2Flash", get(|| async { "upper" }))
        .route("/slash//lash", get(|| async { "/" }))
        .route("/api/user", get(|| async { "user" }))
        .route("/{capture}", capture_route);
    let client = TestClient::new(router);

    // (request path, expected response body)
    let cases = [
        // %2f encodes `/`
        ("/api%2fuser", "capture"),
        ("/slash/%2flash", "lower"),
        ("/slash/%2Flash", "upper"),
        // `%25` encodes `%`
        // This must not be decoded twice
        ("/slash/%252flash", "lower"),
        ("/slash/%252Flash", "upper"),
    ];

    for (uri, expected) in cases {
        let res = client.get(uri).await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.text().await;
        assert_eq!(body, expected);
    }
}
// A route whose pattern itself contains `%61` is only reached by a request
// that encodes the `%` (`%25`); a plain `%61` in the request decodes to `a`.
#[crate::test]
async fn percent_encoding_percent() {
    let router = Router::new()
        .route("/%61pi", get(|| async { "percent" }))
        .route("/api", get(|| async { "api" }));
    let client = TestClient::new(router);

    // (request path, expected response body)
    let cases = [
        ("/api", "api"),
        ("/%61pi", "api"),
        // `%25` encodes `%`
        // This must not be decoded twice, otherwise it will become `/api`
        ("/%2561pi", "percent"),
    ];

    for (uri, expected) in cases {
        let res = client.get(uri).await;
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.text().await;
        assert_eq!(body, expected);
    }
}

View file

@ -1,46 +1,32 @@
use crate::util::PercentDecodedStr;
use http::Extensions;
use matchit::Params;
use std::sync::Arc;
#[derive(Clone)]
pub(crate) enum UrlParams {
Params(Vec<(Arc<str>, PercentDecodedStr)>),
InvalidUtf8InPathParam { key: Arc<str> },
Params(Vec<(Arc<str>, Arc<str>)>),
}
pub(super) fn insert_url_params(extensions: &mut Extensions, params: Params) {
let current_params = extensions.get_mut();
if let Some(UrlParams::InvalidUtf8InPathParam { .. }) = current_params {
// nothing to do here since an error was stored earlier
return;
}
let params = params
.iter()
.filter(|(key, _)| !key.starts_with(super::NEST_TAIL_PARAM))
.filter(|(key, _)| !key.starts_with(super::FALLBACK_PARAM))
.map(|(k, v)| {
if let Some(decoded) = PercentDecodedStr::new(v) {
Ok((Arc::from(k), decoded))
} else {
Err(Arc::from(k))
}
(
Arc::from(k),
Arc::from(v.replace("%2f", "/").replace("%2F", "/")),
)
})
.collect::<Result<Vec<_>, _>>();
.collect::<Vec<_>>();
match (current_params, params) {
(Some(UrlParams::InvalidUtf8InPathParam { .. }), _) => {
unreachable!("we check for this state earlier in this method")
}
(_, Err(invalid_key)) => {
extensions.insert(UrlParams::InvalidUtf8InPathParam { key: invalid_key });
}
(Some(UrlParams::Params(current)), Ok(params)) => {
(Some(UrlParams::Params(current)), params) => {
current.extend(params);
}
(None, Ok(params)) => {
(None, params) => {
extensions.insert(UrlParams::Params(params));
}
}

View file

@ -1,36 +1,7 @@
use pin_project_lite::pin_project;
use std::{ops::Deref, sync::Arc};
pub(crate) use self::mutex::*;
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct PercentDecodedStr(Arc<str>);
impl PercentDecodedStr {
pub(crate) fn new<S>(s: S) -> Option<Self>
where
S: AsRef<str>,
{
percent_encoding::percent_decode(s.as_ref().as_bytes())
.decode_utf8()
.ok()
.map(|decoded| Self(decoded.as_ref().into()))
}
pub(crate) fn as_str(&self) -> &str {
&self.0
}
}
impl Deref for PercentDecodedStr {
type Target = str;
#[inline]
fn deref(&self) -> &Self::Target {
self.as_str()
}
}
pin_project! {
#[project = EitherProj]
pub(crate) enum Either<A, B> {