mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-01 08:56:15 +01:00
Fix routing issues when loading a Router
via a dynamic library (#1806)
This commit is contained in:
parent
6075be60ed
commit
5a58edac16
2 changed files with 22 additions and 15 deletions
|
@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
# Unreleased
|
# Unreleased
|
||||||
|
|
||||||
- **fixed:** Add `#[must_use]` to `WebSocketUpgrade::on_upgrade` ([#1801])
|
- **fixed:** Add `#[must_use]` to `WebSocketUpgrade::on_upgrade` ([#1801])
|
||||||
|
- **fixed:** Fix routing issues when loading a `Router` via a dynamic library ([#1806])
|
||||||
|
|
||||||
|
[#1806]: https://github.com/tokio-rs/axum/pull/1806
|
||||||
|
|
||||||
[#1801]: https://github.com/tokio-rs/axum/pull/1801
|
[#1801]: https://github.com/tokio-rs/axum/pull/1801
|
||||||
|
|
||||||
|
|
|
@ -47,24 +47,12 @@ pub use self::method_routing::{
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||||
pub(crate) struct RouteId(u32);
|
pub(crate) struct RouteId(u32);
|
||||||
|
|
||||||
impl RouteId {
|
|
||||||
fn next() -> Self {
|
|
||||||
use std::sync::atomic::{AtomicU32, Ordering};
|
|
||||||
// `AtomicU64` isn't supported on all platforms
|
|
||||||
static ID: AtomicU32 = AtomicU32::new(0);
|
|
||||||
let id = ID.fetch_add(1, Ordering::Relaxed);
|
|
||||||
if id == u32::MAX {
|
|
||||||
panic!("Over `u32::MAX` routes created. If you need this, please file an issue.");
|
|
||||||
}
|
|
||||||
Self(id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The router type for composing handlers and services.
|
/// The router type for composing handlers and services.
|
||||||
pub struct Router<S = (), B = Body> {
|
pub struct Router<S = (), B = Body> {
|
||||||
routes: HashMap<RouteId, Endpoint<S, B>>,
|
routes: HashMap<RouteId, Endpoint<S, B>>,
|
||||||
node: Arc<Node>,
|
node: Arc<Node>,
|
||||||
fallback: Fallback<S, B>,
|
fallback: Fallback<S, B>,
|
||||||
|
prev_route_id: RouteId,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, B> Clone for Router<S, B> {
|
impl<S, B> Clone for Router<S, B> {
|
||||||
|
@ -73,6 +61,7 @@ impl<S, B> Clone for Router<S, B> {
|
||||||
routes: self.routes.clone(),
|
routes: self.routes.clone(),
|
||||||
node: Arc::clone(&self.node),
|
node: Arc::clone(&self.node),
|
||||||
fallback: self.fallback.clone(),
|
fallback: self.fallback.clone(),
|
||||||
|
prev_route_id: self.prev_route_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -117,6 +106,7 @@ where
|
||||||
routes: Default::default(),
|
routes: Default::default(),
|
||||||
node: Default::default(),
|
node: Default::default(),
|
||||||
fallback: Fallback::Default(Route::new(NotFound)),
|
fallback: Fallback::Default(Route::new(NotFound)),
|
||||||
|
prev_route_id: RouteId(0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -134,7 +124,7 @@ where
|
||||||
|
|
||||||
validate_path(path);
|
validate_path(path);
|
||||||
|
|
||||||
let id = RouteId::next();
|
let id = self.next_route_id();
|
||||||
|
|
||||||
let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
|
let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
|
||||||
.node
|
.node
|
||||||
|
@ -189,7 +179,7 @@ where
|
||||||
panic!("Paths must start with a `/`");
|
panic!("Paths must start with a `/`");
|
||||||
}
|
}
|
||||||
|
|
||||||
let id = RouteId::next();
|
let id = self.next_route_id();
|
||||||
self.set_node(path, id);
|
self.set_node(path, id);
|
||||||
self.routes.insert(id, endpoint);
|
self.routes.insert(id, endpoint);
|
||||||
self
|
self
|
||||||
|
@ -286,6 +276,7 @@ where
|
||||||
routes,
|
routes,
|
||||||
node,
|
node,
|
||||||
fallback,
|
fallback,
|
||||||
|
prev_route_id: _,
|
||||||
} = other.into();
|
} = other.into();
|
||||||
|
|
||||||
for (id, route) in routes {
|
for (id, route) in routes {
|
||||||
|
@ -335,6 +326,7 @@ where
|
||||||
routes,
|
routes,
|
||||||
node: self.node,
|
node: self.node,
|
||||||
fallback,
|
fallback,
|
||||||
|
prev_route_id: self.prev_route_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -368,6 +360,7 @@ where
|
||||||
routes,
|
routes,
|
||||||
node: self.node,
|
node: self.node,
|
||||||
fallback: self.fallback,
|
fallback: self.fallback,
|
||||||
|
prev_route_id: self.prev_route_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -419,6 +412,7 @@ where
|
||||||
routes,
|
routes,
|
||||||
node: self.node,
|
node: self.node,
|
||||||
fallback,
|
fallback,
|
||||||
|
prev_route_id: self.prev_route_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -506,6 +500,16 @@ where
|
||||||
Endpoint::NestedRouter(router) => router.call_with_state(req, state),
|
Endpoint::NestedRouter(router) => router.call_with_state(req, state),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn next_route_id(&mut self) -> RouteId {
|
||||||
|
let next_id = self
|
||||||
|
.prev_route_id
|
||||||
|
.0
|
||||||
|
.checked_add(1)
|
||||||
|
.expect("Over `u32::MAX` routes created. If you need this, please file an issue.");
|
||||||
|
self.prev_route_id = RouteId(next_id);
|
||||||
|
self.prev_route_id
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B> Router<(), B>
|
impl<B> Router<(), B>
|
||||||
|
|
Loading…
Reference in a new issue