From dfa2e3b09f2d20bb6ad00998a46a87b38aab6414 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 8 Mar 2022 21:27:44 +0100 Subject: [PATCH] Fix nesting of opaque services that paths that contain params (#841) * checkpoint * fix it * more consistent macros * fix msrv --- Cargo.toml | 2 +- axum/Cargo.toml | 2 + axum/src/routing/strip_prefix.rs | 408 ++++++++++++++++++++++++++++--- axum/src/routing/tests/nest.rs | 18 ++ 4 files changed, 401 insertions(+), 29 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5120e29c..33d40ae4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,5 +4,5 @@ members = [ "axum-core", "axum-extra", "axum-macros", - "examples/*", + # "examples/*", ] diff --git a/axum/Cargo.toml b/axum/Cargo.toml index ef0e276e..10e4ddcc 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -65,6 +65,8 @@ tokio-stream = "0.1" tracing = "0.1" uuid = { version = "0.8", features = ["serde", "v4"] } anyhow = "1.0" +quickcheck = "1.0" +quickcheck_macros = "1.0" [dev-dependencies.tower] package = "tower" diff --git a/axum/src/routing/strip_prefix.rs b/axum/src/routing/strip_prefix.rs index c7950538..42b41ace 100644 --- a/axum/src/routing/strip_prefix.rs +++ b/axum/src/routing/strip_prefix.rs @@ -1,6 +1,5 @@ use http::{Request, Uri}; use std::{ - borrow::Cow, sync::Arc, task::{Context, Poll}, }; @@ -35,43 +34,396 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { - let new_uri = strip_prefix(req.uri(), &self.prefix); - *req.uri_mut() = new_uri; + if let Some(new_uri) = strip_prefix(req.uri(), &self.prefix) { + *req.uri_mut() = new_uri; + } self.inner.call(req) } } -fn strip_prefix(uri: &Uri, prefix: &str) -> Uri { - let path_and_query = if let Some(path_and_query) = uri.path_and_query() { - let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) { - path +fn strip_prefix(uri: &Uri, prefix: &str) -> Option { + let path_and_query: http::uri::PathAndQuery = if let Some(path_and_query) = uri.path_and_query() + { + // Check whether the prefix matches the path and if so how long the matching prefix is. + // + // # Examples + // + // prefix = /api + // uri = /api/users + // ^^^^ this much is matched and the length is 4. Thus if we chop off the first 4 + // characters we get the remainder + // + // prefix = /api/:version + // uri = /api/v0/users + // ^^^^^^^ this much is matched and the length is 7. + let mut matching_prefix_length = Some(0); + for item in zip_longest(segments(path_and_query.path()), segments(prefix)) { + // count the `/` + *matching_prefix_length.as_mut().unwrap() += 1; + + match item { + Item::Both(path_segment, prefix_segment) => { + if prefix_segment.starts_with(':') || path_segment == prefix_segment { + *matching_prefix_length.as_mut().unwrap() += path_segment.len(); + } else { + matching_prefix_length = None; + break; + } + } + Item::First(_) => { + break; + } + Item::Second(_) => { + matching_prefix_length = None; + break; + } + } + } + + let after_prefix = if let Some(idx) = matching_prefix_length { + uri.path().split_at(idx).1 } else { - path_and_query.path() + return None; }; - let new_path = if new_path.starts_with('/') { - Cow::Borrowed(new_path) - } else { - Cow::Owned(format!("/{}", new_path)) - }; - - if let Some(query) = path_and_query.query() { - Some( - format!("{}?{}", new_path, query) - .parse::() - .unwrap(), - ) - } else { - Some(new_path.parse().unwrap()) + match (after_prefix.starts_with('/'), path_and_query.query()) { + (true, None) => after_prefix.parse().unwrap(), + (true, Some(query)) => format!("{}?{}", after_prefix, query).parse().unwrap(), + (false, None) => format!("/{}", after_prefix).parse().unwrap(), + (false, Some(query)) => format!("/{}?{}", after_prefix, query).parse().unwrap(), } } else { - None + return None; }; - let mut parts = http::uri::Parts::default(); - parts.scheme = uri.scheme().cloned(); - parts.authority = uri.authority().cloned(); - parts.path_and_query = path_and_query; + let mut parts = uri.clone().into_parts(); + parts.path_and_query = Some(path_and_query); - Uri::from_parts(parts).unwrap() + Some(Uri::from_parts(parts).unwrap()) +} + +fn segments(s: &str) -> impl Iterator { + assert!( + s.starts_with('/'), + "path didn't start with '/'. axum should have caught this higher up." + ); + + s.split('/') + // skip one because paths always start with `/` so `/a/b` would become ["", "a", "b"] + // otherwise + .skip(1) +} + +fn zip_longest(a: I, b: I2) -> impl Iterator> +where + I: Iterator, + I2: Iterator, +{ + let a = a.map(Some).chain(std::iter::repeat_with(|| None)); + let b = b.map(Some).chain(std::iter::repeat_with(|| None)); + a.zip(b) + // use `map_while` when its stable in our MSRV + .take_while(|(a, b)| a.is_some() || b.is_some()) + .filter_map(|(a, b)| match (a, b) { + (Some(a), Some(b)) => Some(Item::Both(a, b)), + (Some(a), None) => Some(Item::First(a)), + (None, Some(b)) => Some(Item::Second(b)), + (None, None) => unreachable!("take_while removes these"), + }) +} + +#[derive(Debug)] +enum Item { + Both(T, T), + First(T), + Second(T), +} + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use super::*; + use quickcheck::Arbitrary; + use quickcheck_macros::quickcheck; + + macro_rules! test { + ( + $name:ident, + uri = $uri:literal, + prefix = $prefix:literal, + expected = $expected:expr, + ) => { + #[test] + fn $name() { + let uri = $uri.parse().unwrap(); + let new_uri = strip_prefix(&uri, $prefix).map(|uri| uri.to_string()); + assert_eq!(new_uri.as_deref(), $expected); + } + }; + } + + test!(empty, uri = "/", prefix = "/", expected = Some("/"),); + + test!( + single_segment, + uri = "/a", + prefix = "/a", + expected = Some("/"), + ); + + test!( + single_segment_root_uri, + uri = "/", + prefix = "/a", + expected = None, + ); + + // the prefix is empty, so removing it should have no effect + test!( + single_segment_root_prefix, + uri = "/a", + prefix = "/", + expected = None, + ); + + test!( + single_segment_no_match, + uri = "/a", + prefix = "/b", + expected = None, + ); + + test!( + single_segment_trailing_slash, + uri = "/a/", + prefix = "/a/", + expected = Some("/"), + ); + + test!( + single_segment_trailing_slash_2, + uri = "/a", + prefix = "/a/", + expected = None, + ); + + test!( + single_segment_trailing_slash_3, + uri = "/a/", + prefix = "/a", + expected = Some("/"), + ); + + test!( + multi_segment, + uri = "/a/b", + prefix = "/a", + expected = Some("/b"), + ); + + test!( + multi_segment_2, + uri = "/b/a", + prefix = "/a", + expected = None, + ); + + test!( + multi_segment_3, + uri = "/a", + prefix = "/a/b", + expected = None, + ); + + test!( + multi_segment_4, + uri = "/a/b", + prefix = "/b", + expected = None, + ); + + test!( + multi_segment_trailing_slash, + uri = "/a/b/", + prefix = "/a/b/", + expected = Some("/"), + ); + + test!( + multi_segment_trailing_slash_2, + uri = "/a/b", + prefix = "/a/b/", + expected = None, + ); + + test!( + multi_segment_trailing_slash_3, + uri = "/a/b/", + prefix = "/a/b", + expected = Some("/"), + ); + + test!(param_0, uri = "/", prefix = "/:param", expected = Some("/"),); + + test!( + param_1, + uri = "/a", + prefix = "/:param", + expected = Some("/"), + ); + + test!( + param_2, + uri = "/a/b", + prefix = "/:param", + expected = Some("/b"), + ); + + test!( + param_3, + uri = "/b/a", + prefix = "/:param", + expected = Some("/a"), + ); + + test!( + param_4, + uri = "/a/b", + prefix = "/a/:param", + expected = Some("/"), + ); + + test!(param_5, uri = "/b/a", prefix = "/a/:param", expected = None,); + + test!(param_6, uri = "/a/b", prefix = "/:param/a", expected = None,); + + test!( + param_7, + uri = "/b/a", + prefix = "/:param/a", + expected = Some("/"), + ); + + test!( + param_8, + uri = "/a/b/c", + prefix = "/a/:param/c", + expected = Some("/"), + ); + + test!( + param_9, + uri = "/c/b/a", + prefix = "/a/:param/c", + expected = None, + ); + + test!( + param_10, + uri = "/a/", + prefix = "/:param", + expected = Some("/"), + ); + + test!(param_11, uri = "/a", prefix = "/:param/", expected = None,); + + test!( + param_12, + uri = "/a/", + prefix = "/:param/", + expected = Some("/"), + ); + + #[quickcheck] + fn does_not_panic(uri_and_prefix: UriAndPrefix) -> bool { + let UriAndPrefix { uri, prefix } = uri_and_prefix; + strip_prefix(&uri, &prefix); + true + } + + #[derive(Clone, Debug)] + struct UriAndPrefix { + uri: Uri, + prefix: String, + } + + impl Arbitrary for UriAndPrefix { + fn arbitrary(g: &mut quickcheck::Gen) -> Self { + let mut uri = String::new(); + let mut prefix = String::new(); + + let size = u8_between(1, 20, g); + + for _ in 0..size { + let segment = ascii_alphanumeric(g); + + uri.push('/'); + uri.push_str(&segment); + + prefix.push('/'); + + let make_matching_segment = bool::arbitrary(g); + let make_capture = bool::arbitrary(g); + + match (make_matching_segment, make_capture) { + (_, true) => { + prefix.push_str(":a"); + } + (true, false) => { + prefix.push_str(&segment); + } + (false, false) => { + prefix.push_str(&ascii_alphanumeric(g)); + } + } + } + + if bool::arbitrary(g) { + uri.push('/'); + } + + if bool::arbitrary(g) { + prefix.push('/'); + } + + Self { + uri: uri.parse().unwrap(), + prefix, + } + } + } + + fn ascii_alphanumeric(g: &mut quickcheck::Gen) -> String { + #[derive(Clone)] + struct AsciiAlphanumeric(String); + + impl Arbitrary for AsciiAlphanumeric { + fn arbitrary(g: &mut quickcheck::Gen) -> Self { + let mut out = String::new(); + + let size = u8_between(1, 20, g) as usize; + + while out.len() < size { + let c = char::arbitrary(g); + if c.is_ascii_alphanumeric() { + out.push(c); + } + } + Self(out) + } + } + + let out = AsciiAlphanumeric::arbitrary(g).0; + assert!(!out.is_empty()); + out + } + + fn u8_between(lower: u8, upper: u8, g: &mut quickcheck::Gen) -> u8 { + loop { + let size = u8::arbitrary(g); + if size > lower && size <= upper { + break size; + } + } + } } diff --git a/axum/src/routing/tests/nest.rs b/axum/src/routing/tests/nest.rs index 569c8785..a311186c 100644 --- a/axum/src/routing/tests/nest.rs +++ b/axum/src/routing/tests/nest.rs @@ -346,6 +346,24 @@ async fn outer_middleware_still_see_whole_url() { assert_eq!(client.get("/one/two").send().await.text().await, "/one/two"); } +#[tokio::test] +async fn nest_at_capture() { + let api_routes = Router::new() + .route( + "/:b", + get(|Path((a, b)): Path<(String, String)>| async move { format!("a={} b={}", a, b) }), + ) + .boxed_clone(); + + let app = Router::new().nest("/:a", api_routes); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "a=foo b=bar"); +} + macro_rules! nested_route_test { ( $name:ident,