Fix nesting of opaque services that paths that contain params (#841)

* checkpoint

* fix it

* more consistent macros

* fix msrv
This commit is contained in:
David Pedersen 2022-03-08 21:27:44 +01:00 committed by GitHub
parent 0600eff31a
commit dfa2e3b09f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 401 additions and 29 deletions

View file

@ -4,5 +4,5 @@ members = [
"axum-core",
"axum-extra",
"axum-macros",
"examples/*",
# "examples/*",
]

View file

@ -65,6 +65,8 @@ tokio-stream = "0.1"
tracing = "0.1"
uuid = { version = "0.8", features = ["serde", "v4"] }
anyhow = "1.0"
quickcheck = "1.0"
quickcheck_macros = "1.0"
[dev-dependencies.tower]
package = "tower"

View file

@ -1,6 +1,5 @@
use http::{Request, Uri};
use std::{
borrow::Cow,
sync::Arc,
task::{Context, Poll},
};
@ -35,43 +34,396 @@ where
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
let new_uri = strip_prefix(req.uri(), &self.prefix);
*req.uri_mut() = new_uri;
if let Some(new_uri) = strip_prefix(req.uri(), &self.prefix) {
*req.uri_mut() = new_uri;
}
self.inner.call(req)
}
}
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
path
fn strip_prefix(uri: &Uri, prefix: &str) -> Option<Uri> {
let path_and_query: http::uri::PathAndQuery = if let Some(path_and_query) = uri.path_and_query()
{
// Check whether the prefix matches the path and if so how long the matching prefix is.
//
// # Examples
//
// prefix = /api
// uri = /api/users
// ^^^^ this much is matched and the length is 4. Thus if we chop off the first 4
// characters we get the remainder
//
// prefix = /api/:version
// uri = /api/v0/users
// ^^^^^^^ this much is matched and the length is 7.
let mut matching_prefix_length = Some(0);
for item in zip_longest(segments(path_and_query.path()), segments(prefix)) {
// count the `/`
*matching_prefix_length.as_mut().unwrap() += 1;
match item {
Item::Both(path_segment, prefix_segment) => {
if prefix_segment.starts_with(':') || path_segment == prefix_segment {
*matching_prefix_length.as_mut().unwrap() += path_segment.len();
} else {
matching_prefix_length = None;
break;
}
}
Item::First(_) => {
break;
}
Item::Second(_) => {
matching_prefix_length = None;
break;
}
}
}
let after_prefix = if let Some(idx) = matching_prefix_length {
uri.path().split_at(idx).1
} else {
path_and_query.path()
return None;
};
let new_path = if new_path.starts_with('/') {
Cow::Borrowed(new_path)
} else {
Cow::Owned(format!("/{}", new_path))
};
if let Some(query) = path_and_query.query() {
Some(
format!("{}?{}", new_path, query)
.parse::<http::uri::PathAndQuery>()
.unwrap(),
)
} else {
Some(new_path.parse().unwrap())
match (after_prefix.starts_with('/'), path_and_query.query()) {
(true, None) => after_prefix.parse().unwrap(),
(true, Some(query)) => format!("{}?{}", after_prefix, query).parse().unwrap(),
(false, None) => format!("/{}", after_prefix).parse().unwrap(),
(false, Some(query)) => format!("/{}?{}", after_prefix, query).parse().unwrap(),
}
} else {
None
return None;
};
let mut parts = http::uri::Parts::default();
parts.scheme = uri.scheme().cloned();
parts.authority = uri.authority().cloned();
parts.path_and_query = path_and_query;
let mut parts = uri.clone().into_parts();
parts.path_and_query = Some(path_and_query);
Uri::from_parts(parts).unwrap()
Some(Uri::from_parts(parts).unwrap())
}
fn segments(s: &str) -> impl Iterator<Item = &str> {
assert!(
s.starts_with('/'),
"path didn't start with '/'. axum should have caught this higher up."
);
s.split('/')
// skip one because paths always start with `/` so `/a/b` would become ["", "a", "b"]
// otherwise
.skip(1)
}
fn zip_longest<I, I2>(a: I, b: I2) -> impl Iterator<Item = Item<I::Item>>
where
I: Iterator,
I2: Iterator<Item = I::Item>,
{
let a = a.map(Some).chain(std::iter::repeat_with(|| None));
let b = b.map(Some).chain(std::iter::repeat_with(|| None));
a.zip(b)
// use `map_while` when its stable in our MSRV
.take_while(|(a, b)| a.is_some() || b.is_some())
.filter_map(|(a, b)| match (a, b) {
(Some(a), Some(b)) => Some(Item::Both(a, b)),
(Some(a), None) => Some(Item::First(a)),
(None, Some(b)) => Some(Item::Second(b)),
(None, None) => unreachable!("take_while removes these"),
})
}
#[derive(Debug)]
enum Item<T> {
Both(T, T),
First(T),
Second(T),
}
#[cfg(test)]
mod tests {
#[allow(unused_imports)]
use super::*;
use quickcheck::Arbitrary;
use quickcheck_macros::quickcheck;
macro_rules! test {
(
$name:ident,
uri = $uri:literal,
prefix = $prefix:literal,
expected = $expected:expr,
) => {
#[test]
fn $name() {
let uri = $uri.parse().unwrap();
let new_uri = strip_prefix(&uri, $prefix).map(|uri| uri.to_string());
assert_eq!(new_uri.as_deref(), $expected);
}
};
}
test!(empty, uri = "/", prefix = "/", expected = Some("/"),);
test!(
single_segment,
uri = "/a",
prefix = "/a",
expected = Some("/"),
);
test!(
single_segment_root_uri,
uri = "/",
prefix = "/a",
expected = None,
);
// the prefix is empty, so removing it should have no effect
test!(
single_segment_root_prefix,
uri = "/a",
prefix = "/",
expected = None,
);
test!(
single_segment_no_match,
uri = "/a",
prefix = "/b",
expected = None,
);
test!(
single_segment_trailing_slash,
uri = "/a/",
prefix = "/a/",
expected = Some("/"),
);
test!(
single_segment_trailing_slash_2,
uri = "/a",
prefix = "/a/",
expected = None,
);
test!(
single_segment_trailing_slash_3,
uri = "/a/",
prefix = "/a",
expected = Some("/"),
);
test!(
multi_segment,
uri = "/a/b",
prefix = "/a",
expected = Some("/b"),
);
test!(
multi_segment_2,
uri = "/b/a",
prefix = "/a",
expected = None,
);
test!(
multi_segment_3,
uri = "/a",
prefix = "/a/b",
expected = None,
);
test!(
multi_segment_4,
uri = "/a/b",
prefix = "/b",
expected = None,
);
test!(
multi_segment_trailing_slash,
uri = "/a/b/",
prefix = "/a/b/",
expected = Some("/"),
);
test!(
multi_segment_trailing_slash_2,
uri = "/a/b",
prefix = "/a/b/",
expected = None,
);
test!(
multi_segment_trailing_slash_3,
uri = "/a/b/",
prefix = "/a/b",
expected = Some("/"),
);
test!(param_0, uri = "/", prefix = "/:param", expected = Some("/"),);
test!(
param_1,
uri = "/a",
prefix = "/:param",
expected = Some("/"),
);
test!(
param_2,
uri = "/a/b",
prefix = "/:param",
expected = Some("/b"),
);
test!(
param_3,
uri = "/b/a",
prefix = "/:param",
expected = Some("/a"),
);
test!(
param_4,
uri = "/a/b",
prefix = "/a/:param",
expected = Some("/"),
);
test!(param_5, uri = "/b/a", prefix = "/a/:param", expected = None,);
test!(param_6, uri = "/a/b", prefix = "/:param/a", expected = None,);
test!(
param_7,
uri = "/b/a",
prefix = "/:param/a",
expected = Some("/"),
);
test!(
param_8,
uri = "/a/b/c",
prefix = "/a/:param/c",
expected = Some("/"),
);
test!(
param_9,
uri = "/c/b/a",
prefix = "/a/:param/c",
expected = None,
);
test!(
param_10,
uri = "/a/",
prefix = "/:param",
expected = Some("/"),
);
test!(param_11, uri = "/a", prefix = "/:param/", expected = None,);
test!(
param_12,
uri = "/a/",
prefix = "/:param/",
expected = Some("/"),
);
#[quickcheck]
fn does_not_panic(uri_and_prefix: UriAndPrefix) -> bool {
let UriAndPrefix { uri, prefix } = uri_and_prefix;
strip_prefix(&uri, &prefix);
true
}
#[derive(Clone, Debug)]
struct UriAndPrefix {
uri: Uri,
prefix: String,
}
impl Arbitrary for UriAndPrefix {
fn arbitrary(g: &mut quickcheck::Gen) -> Self {
let mut uri = String::new();
let mut prefix = String::new();
let size = u8_between(1, 20, g);
for _ in 0..size {
let segment = ascii_alphanumeric(g);
uri.push('/');
uri.push_str(&segment);
prefix.push('/');
let make_matching_segment = bool::arbitrary(g);
let make_capture = bool::arbitrary(g);
match (make_matching_segment, make_capture) {
(_, true) => {
prefix.push_str(":a");
}
(true, false) => {
prefix.push_str(&segment);
}
(false, false) => {
prefix.push_str(&ascii_alphanumeric(g));
}
}
}
if bool::arbitrary(g) {
uri.push('/');
}
if bool::arbitrary(g) {
prefix.push('/');
}
Self {
uri: uri.parse().unwrap(),
prefix,
}
}
}
fn ascii_alphanumeric(g: &mut quickcheck::Gen) -> String {
#[derive(Clone)]
struct AsciiAlphanumeric(String);
impl Arbitrary for AsciiAlphanumeric {
fn arbitrary(g: &mut quickcheck::Gen) -> Self {
let mut out = String::new();
let size = u8_between(1, 20, g) as usize;
while out.len() < size {
let c = char::arbitrary(g);
if c.is_ascii_alphanumeric() {
out.push(c);
}
}
Self(out)
}
}
let out = AsciiAlphanumeric::arbitrary(g).0;
assert!(!out.is_empty());
out
}
fn u8_between(lower: u8, upper: u8, g: &mut quickcheck::Gen) -> u8 {
loop {
let size = u8::arbitrary(g);
if size > lower && size <= upper {
break size;
}
}
}
}

View file

@ -346,6 +346,24 @@ async fn outer_middleware_still_see_whole_url() {
assert_eq!(client.get("/one/two").send().await.text().await, "/one/two");
}
#[tokio::test]
async fn nest_at_capture() {
let api_routes = Router::new()
.route(
"/:b",
get(|Path((a, b)): Path<(String, String)>| async move { format!("a={} b={}", a, b) }),
)
.boxed_clone();
let app = Router::new().nest("/:a", api_routes);
let client = TestClient::new(app);
let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "a=foo b=bar");
}
macro_rules! nested_route_test {
(
$name:ident,