mirror of
https://github.com/tokio-rs/axum.git
synced 2024-11-25 00:28:07 +01:00
Fix nesting of opaque services that paths that contain params (#841)
* checkpoint * fix it * more consistent macros * fix msrv
This commit is contained in:
parent
0600eff31a
commit
dfa2e3b09f
4 changed files with 401 additions and 29 deletions
|
@ -4,5 +4,5 @@ members = [
|
|||
"axum-core",
|
||||
"axum-extra",
|
||||
"axum-macros",
|
||||
"examples/*",
|
||||
# "examples/*",
|
||||
]
|
||||
|
|
|
@ -65,6 +65,8 @@ tokio-stream = "0.1"
|
|||
tracing = "0.1"
|
||||
uuid = { version = "0.8", features = ["serde", "v4"] }
|
||||
anyhow = "1.0"
|
||||
quickcheck = "1.0"
|
||||
quickcheck_macros = "1.0"
|
||||
|
||||
[dev-dependencies.tower]
|
||||
package = "tower"
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
use http::{Request, Uri};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
@ -35,43 +34,396 @@ where
|
|||
}
|
||||
|
||||
fn call(&mut self, mut req: Request<B>) -> Self::Future {
|
||||
let new_uri = strip_prefix(req.uri(), &self.prefix);
|
||||
*req.uri_mut() = new_uri;
|
||||
if let Some(new_uri) = strip_prefix(req.uri(), &self.prefix) {
|
||||
*req.uri_mut() = new_uri;
|
||||
}
|
||||
self.inner.call(req)
|
||||
}
|
||||
}
|
||||
|
||||
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
|
||||
let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
|
||||
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
|
||||
path
|
||||
fn strip_prefix(uri: &Uri, prefix: &str) -> Option<Uri> {
|
||||
let path_and_query: http::uri::PathAndQuery = if let Some(path_and_query) = uri.path_and_query()
|
||||
{
|
||||
// Check whether the prefix matches the path and if so how long the matching prefix is.
|
||||
//
|
||||
// # Examples
|
||||
//
|
||||
// prefix = /api
|
||||
// uri = /api/users
|
||||
// ^^^^ this much is matched and the length is 4. Thus if we chop off the first 4
|
||||
// characters we get the remainder
|
||||
//
|
||||
// prefix = /api/:version
|
||||
// uri = /api/v0/users
|
||||
// ^^^^^^^ this much is matched and the length is 7.
|
||||
let mut matching_prefix_length = Some(0);
|
||||
for item in zip_longest(segments(path_and_query.path()), segments(prefix)) {
|
||||
// count the `/`
|
||||
*matching_prefix_length.as_mut().unwrap() += 1;
|
||||
|
||||
match item {
|
||||
Item::Both(path_segment, prefix_segment) => {
|
||||
if prefix_segment.starts_with(':') || path_segment == prefix_segment {
|
||||
*matching_prefix_length.as_mut().unwrap() += path_segment.len();
|
||||
} else {
|
||||
matching_prefix_length = None;
|
||||
break;
|
||||
}
|
||||
}
|
||||
Item::First(_) => {
|
||||
break;
|
||||
}
|
||||
Item::Second(_) => {
|
||||
matching_prefix_length = None;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let after_prefix = if let Some(idx) = matching_prefix_length {
|
||||
uri.path().split_at(idx).1
|
||||
} else {
|
||||
path_and_query.path()
|
||||
return None;
|
||||
};
|
||||
|
||||
let new_path = if new_path.starts_with('/') {
|
||||
Cow::Borrowed(new_path)
|
||||
} else {
|
||||
Cow::Owned(format!("/{}", new_path))
|
||||
};
|
||||
|
||||
if let Some(query) = path_and_query.query() {
|
||||
Some(
|
||||
format!("{}?{}", new_path, query)
|
||||
.parse::<http::uri::PathAndQuery>()
|
||||
.unwrap(),
|
||||
)
|
||||
} else {
|
||||
Some(new_path.parse().unwrap())
|
||||
match (after_prefix.starts_with('/'), path_and_query.query()) {
|
||||
(true, None) => after_prefix.parse().unwrap(),
|
||||
(true, Some(query)) => format!("{}?{}", after_prefix, query).parse().unwrap(),
|
||||
(false, None) => format!("/{}", after_prefix).parse().unwrap(),
|
||||
(false, Some(query)) => format!("/{}?{}", after_prefix, query).parse().unwrap(),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
return None;
|
||||
};
|
||||
|
||||
let mut parts = http::uri::Parts::default();
|
||||
parts.scheme = uri.scheme().cloned();
|
||||
parts.authority = uri.authority().cloned();
|
||||
parts.path_and_query = path_and_query;
|
||||
let mut parts = uri.clone().into_parts();
|
||||
parts.path_and_query = Some(path_and_query);
|
||||
|
||||
Uri::from_parts(parts).unwrap()
|
||||
Some(Uri::from_parts(parts).unwrap())
|
||||
}
|
||||
|
||||
fn segments(s: &str) -> impl Iterator<Item = &str> {
|
||||
assert!(
|
||||
s.starts_with('/'),
|
||||
"path didn't start with '/'. axum should have caught this higher up."
|
||||
);
|
||||
|
||||
s.split('/')
|
||||
// skip one because paths always start with `/` so `/a/b` would become ["", "a", "b"]
|
||||
// otherwise
|
||||
.skip(1)
|
||||
}
|
||||
|
||||
fn zip_longest<I, I2>(a: I, b: I2) -> impl Iterator<Item = Item<I::Item>>
|
||||
where
|
||||
I: Iterator,
|
||||
I2: Iterator<Item = I::Item>,
|
||||
{
|
||||
let a = a.map(Some).chain(std::iter::repeat_with(|| None));
|
||||
let b = b.map(Some).chain(std::iter::repeat_with(|| None));
|
||||
a.zip(b)
|
||||
// use `map_while` when its stable in our MSRV
|
||||
.take_while(|(a, b)| a.is_some() || b.is_some())
|
||||
.filter_map(|(a, b)| match (a, b) {
|
||||
(Some(a), Some(b)) => Some(Item::Both(a, b)),
|
||||
(Some(a), None) => Some(Item::First(a)),
|
||||
(None, Some(b)) => Some(Item::Second(b)),
|
||||
(None, None) => unreachable!("take_while removes these"),
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Item<T> {
|
||||
Both(T, T),
|
||||
First(T),
|
||||
Second(T),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[allow(unused_imports)]
|
||||
use super::*;
|
||||
use quickcheck::Arbitrary;
|
||||
use quickcheck_macros::quickcheck;
|
||||
|
||||
macro_rules! test {
|
||||
(
|
||||
$name:ident,
|
||||
uri = $uri:literal,
|
||||
prefix = $prefix:literal,
|
||||
expected = $expected:expr,
|
||||
) => {
|
||||
#[test]
|
||||
fn $name() {
|
||||
let uri = $uri.parse().unwrap();
|
||||
let new_uri = strip_prefix(&uri, $prefix).map(|uri| uri.to_string());
|
||||
assert_eq!(new_uri.as_deref(), $expected);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
test!(empty, uri = "/", prefix = "/", expected = Some("/"),);
|
||||
|
||||
test!(
|
||||
single_segment,
|
||||
uri = "/a",
|
||||
prefix = "/a",
|
||||
expected = Some("/"),
|
||||
);
|
||||
|
||||
test!(
|
||||
single_segment_root_uri,
|
||||
uri = "/",
|
||||
prefix = "/a",
|
||||
expected = None,
|
||||
);
|
||||
|
||||
// the prefix is empty, so removing it should have no effect
|
||||
test!(
|
||||
single_segment_root_prefix,
|
||||
uri = "/a",
|
||||
prefix = "/",
|
||||
expected = None,
|
||||
);
|
||||
|
||||
test!(
|
||||
single_segment_no_match,
|
||||
uri = "/a",
|
||||
prefix = "/b",
|
||||
expected = None,
|
||||
);
|
||||
|
||||
test!(
|
||||
single_segment_trailing_slash,
|
||||
uri = "/a/",
|
||||
prefix = "/a/",
|
||||
expected = Some("/"),
|
||||
);
|
||||
|
||||
test!(
|
||||
single_segment_trailing_slash_2,
|
||||
uri = "/a",
|
||||
prefix = "/a/",
|
||||
expected = None,
|
||||
);
|
||||
|
||||
test!(
|
||||
single_segment_trailing_slash_3,
|
||||
uri = "/a/",
|
||||
prefix = "/a",
|
||||
expected = Some("/"),
|
||||
);
|
||||
|
||||
test!(
|
||||
multi_segment,
|
||||
uri = "/a/b",
|
||||
prefix = "/a",
|
||||
expected = Some("/b"),
|
||||
);
|
||||
|
||||
test!(
|
||||
multi_segment_2,
|
||||
uri = "/b/a",
|
||||
prefix = "/a",
|
||||
expected = None,
|
||||
);
|
||||
|
||||
test!(
|
||||
multi_segment_3,
|
||||
uri = "/a",
|
||||
prefix = "/a/b",
|
||||
expected = None,
|
||||
);
|
||||
|
||||
test!(
|
||||
multi_segment_4,
|
||||
uri = "/a/b",
|
||||
prefix = "/b",
|
||||
expected = None,
|
||||
);
|
||||
|
||||
test!(
|
||||
multi_segment_trailing_slash,
|
||||
uri = "/a/b/",
|
||||
prefix = "/a/b/",
|
||||
expected = Some("/"),
|
||||
);
|
||||
|
||||
test!(
|
||||
multi_segment_trailing_slash_2,
|
||||
uri = "/a/b",
|
||||
prefix = "/a/b/",
|
||||
expected = None,
|
||||
);
|
||||
|
||||
test!(
|
||||
multi_segment_trailing_slash_3,
|
||||
uri = "/a/b/",
|
||||
prefix = "/a/b",
|
||||
expected = Some("/"),
|
||||
);
|
||||
|
||||
test!(param_0, uri = "/", prefix = "/:param", expected = Some("/"),);
|
||||
|
||||
test!(
|
||||
param_1,
|
||||
uri = "/a",
|
||||
prefix = "/:param",
|
||||
expected = Some("/"),
|
||||
);
|
||||
|
||||
test!(
|
||||
param_2,
|
||||
uri = "/a/b",
|
||||
prefix = "/:param",
|
||||
expected = Some("/b"),
|
||||
);
|
||||
|
||||
test!(
|
||||
param_3,
|
||||
uri = "/b/a",
|
||||
prefix = "/:param",
|
||||
expected = Some("/a"),
|
||||
);
|
||||
|
||||
test!(
|
||||
param_4,
|
||||
uri = "/a/b",
|
||||
prefix = "/a/:param",
|
||||
expected = Some("/"),
|
||||
);
|
||||
|
||||
test!(param_5, uri = "/b/a", prefix = "/a/:param", expected = None,);
|
||||
|
||||
test!(param_6, uri = "/a/b", prefix = "/:param/a", expected = None,);
|
||||
|
||||
test!(
|
||||
param_7,
|
||||
uri = "/b/a",
|
||||
prefix = "/:param/a",
|
||||
expected = Some("/"),
|
||||
);
|
||||
|
||||
test!(
|
||||
param_8,
|
||||
uri = "/a/b/c",
|
||||
prefix = "/a/:param/c",
|
||||
expected = Some("/"),
|
||||
);
|
||||
|
||||
test!(
|
||||
param_9,
|
||||
uri = "/c/b/a",
|
||||
prefix = "/a/:param/c",
|
||||
expected = None,
|
||||
);
|
||||
|
||||
test!(
|
||||
param_10,
|
||||
uri = "/a/",
|
||||
prefix = "/:param",
|
||||
expected = Some("/"),
|
||||
);
|
||||
|
||||
test!(param_11, uri = "/a", prefix = "/:param/", expected = None,);
|
||||
|
||||
test!(
|
||||
param_12,
|
||||
uri = "/a/",
|
||||
prefix = "/:param/",
|
||||
expected = Some("/"),
|
||||
);
|
||||
|
||||
#[quickcheck]
|
||||
fn does_not_panic(uri_and_prefix: UriAndPrefix) -> bool {
|
||||
let UriAndPrefix { uri, prefix } = uri_and_prefix;
|
||||
strip_prefix(&uri, &prefix);
|
||||
true
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct UriAndPrefix {
|
||||
uri: Uri,
|
||||
prefix: String,
|
||||
}
|
||||
|
||||
impl Arbitrary for UriAndPrefix {
|
||||
fn arbitrary(g: &mut quickcheck::Gen) -> Self {
|
||||
let mut uri = String::new();
|
||||
let mut prefix = String::new();
|
||||
|
||||
let size = u8_between(1, 20, g);
|
||||
|
||||
for _ in 0..size {
|
||||
let segment = ascii_alphanumeric(g);
|
||||
|
||||
uri.push('/');
|
||||
uri.push_str(&segment);
|
||||
|
||||
prefix.push('/');
|
||||
|
||||
let make_matching_segment = bool::arbitrary(g);
|
||||
let make_capture = bool::arbitrary(g);
|
||||
|
||||
match (make_matching_segment, make_capture) {
|
||||
(_, true) => {
|
||||
prefix.push_str(":a");
|
||||
}
|
||||
(true, false) => {
|
||||
prefix.push_str(&segment);
|
||||
}
|
||||
(false, false) => {
|
||||
prefix.push_str(&ascii_alphanumeric(g));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if bool::arbitrary(g) {
|
||||
uri.push('/');
|
||||
}
|
||||
|
||||
if bool::arbitrary(g) {
|
||||
prefix.push('/');
|
||||
}
|
||||
|
||||
Self {
|
||||
uri: uri.parse().unwrap(),
|
||||
prefix,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn ascii_alphanumeric(g: &mut quickcheck::Gen) -> String {
|
||||
#[derive(Clone)]
|
||||
struct AsciiAlphanumeric(String);
|
||||
|
||||
impl Arbitrary for AsciiAlphanumeric {
|
||||
fn arbitrary(g: &mut quickcheck::Gen) -> Self {
|
||||
let mut out = String::new();
|
||||
|
||||
let size = u8_between(1, 20, g) as usize;
|
||||
|
||||
while out.len() < size {
|
||||
let c = char::arbitrary(g);
|
||||
if c.is_ascii_alphanumeric() {
|
||||
out.push(c);
|
||||
}
|
||||
}
|
||||
Self(out)
|
||||
}
|
||||
}
|
||||
|
||||
let out = AsciiAlphanumeric::arbitrary(g).0;
|
||||
assert!(!out.is_empty());
|
||||
out
|
||||
}
|
||||
|
||||
fn u8_between(lower: u8, upper: u8, g: &mut quickcheck::Gen) -> u8 {
|
||||
loop {
|
||||
let size = u8::arbitrary(g);
|
||||
if size > lower && size <= upper {
|
||||
break size;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -346,6 +346,24 @@ async fn outer_middleware_still_see_whole_url() {
|
|||
assert_eq!(client.get("/one/two").send().await.text().await, "/one/two");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn nest_at_capture() {
|
||||
let api_routes = Router::new()
|
||||
.route(
|
||||
"/:b",
|
||||
get(|Path((a, b)): Path<(String, String)>| async move { format!("a={} b={}", a, b) }),
|
||||
)
|
||||
.boxed_clone();
|
||||
|
||||
let app = Router::new().nest("/:a", api_routes);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/foo/bar").send().await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await, "a=foo b=bar");
|
||||
}
|
||||
|
||||
macro_rules! nested_route_test {
|
||||
(
|
||||
$name:ident,
|
||||
|
|
Loading…
Reference in a new issue