Fix routing issues when loading a Router via a dynamic library (#1806)

This commit is contained in:
David Pedersen 2023-03-03 13:23:53 +01:00 committed by GitHub
parent 6075be60ed
commit 5a58edac16
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 15 deletions

View file

@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased # Unreleased
- **fixed:** Add `#[must_use]` to `WebSocketUpgrade::on_upgrade` ([#1801]) - **fixed:** Add `#[must_use]` to `WebSocketUpgrade::on_upgrade` ([#1801])
- **fixed:** Fix routing issues when loading a `Router` via a dynamic library ([#1806])
[#1806]: https://github.com/tokio-rs/axum/pull/1806
[#1801]: https://github.com/tokio-rs/axum/pull/1801 [#1801]: https://github.com/tokio-rs/axum/pull/1801

View file

@ -47,24 +47,12 @@ pub use self::method_routing::{
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct RouteId(u32); pub(crate) struct RouteId(u32);
impl RouteId {
fn next() -> Self {
use std::sync::atomic::{AtomicU32, Ordering};
// `AtomicU64` isn't supported on all platforms
static ID: AtomicU32 = AtomicU32::new(0);
let id = ID.fetch_add(1, Ordering::Relaxed);
if id == u32::MAX {
panic!("Over `u32::MAX` routes created. If you need this, please file an issue.");
}
Self(id)
}
}
/// The router type for composing handlers and services. /// The router type for composing handlers and services.
pub struct Router<S = (), B = Body> { pub struct Router<S = (), B = Body> {
routes: HashMap<RouteId, Endpoint<S, B>>, routes: HashMap<RouteId, Endpoint<S, B>>,
node: Arc<Node>, node: Arc<Node>,
fallback: Fallback<S, B>, fallback: Fallback<S, B>,
prev_route_id: RouteId,
} }
impl<S, B> Clone for Router<S, B> { impl<S, B> Clone for Router<S, B> {
@ -73,6 +61,7 @@ impl<S, B> Clone for Router<S, B> {
routes: self.routes.clone(), routes: self.routes.clone(),
node: Arc::clone(&self.node), node: Arc::clone(&self.node),
fallback: self.fallback.clone(), fallback: self.fallback.clone(),
prev_route_id: self.prev_route_id,
} }
} }
} }
@ -117,6 +106,7 @@ where
routes: Default::default(), routes: Default::default(),
node: Default::default(), node: Default::default(),
fallback: Fallback::Default(Route::new(NotFound)), fallback: Fallback::Default(Route::new(NotFound)),
prev_route_id: RouteId(0),
} }
} }
@ -134,7 +124,7 @@ where
validate_path(path); validate_path(path);
let id = RouteId::next(); let id = self.next_route_id();
let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
.node .node
@ -189,7 +179,7 @@ where
panic!("Paths must start with a `/`"); panic!("Paths must start with a `/`");
} }
let id = RouteId::next(); let id = self.next_route_id();
self.set_node(path, id); self.set_node(path, id);
self.routes.insert(id, endpoint); self.routes.insert(id, endpoint);
self self
@ -286,6 +276,7 @@ where
routes, routes,
node, node,
fallback, fallback,
prev_route_id: _,
} = other.into(); } = other.into();
for (id, route) in routes { for (id, route) in routes {
@ -335,6 +326,7 @@ where
routes, routes,
node: self.node, node: self.node,
fallback, fallback,
prev_route_id: self.prev_route_id,
} }
} }
@ -368,6 +360,7 @@ where
routes, routes,
node: self.node, node: self.node,
fallback: self.fallback, fallback: self.fallback,
prev_route_id: self.prev_route_id,
} }
} }
@ -419,6 +412,7 @@ where
routes, routes,
node: self.node, node: self.node,
fallback, fallback,
prev_route_id: self.prev_route_id,
} }
} }
@ -506,6 +500,16 @@ where
Endpoint::NestedRouter(router) => router.call_with_state(req, state), Endpoint::NestedRouter(router) => router.call_with_state(req, state),
} }
} }
fn next_route_id(&mut self) -> RouteId {
let next_id = self
.prev_route_id
.0
.checked_add(1)
.expect("Over `u32::MAX` routes created. If you need this, please file an issue.");
self.prev_route_id = RouteId(next_id);
self.prev_route_id
}
} }
impl<B> Router<(), B> impl<B> Router<(), B>