Rewrite how state is passed from Router to MethodRouter

This commit is contained in:
Jonas Platte 2022-10-09 21:33:40 +02:00 committed by GitHub
parent 7cbacd1433
commit a2ab338e68
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 264 additions and 242 deletions

View file

@ -1,85 +0,0 @@
use super::Handler;
use crate::response::Response;
use http::Request;
use std::{
convert::Infallible,
fmt,
marker::PhantomData,
sync::Arc,
task::{Context, Poll},
};
use tower_service::Service;
pub(crate) struct IntoServiceStateInExtension<H, T, S, B> {
handler: H,
_marker: PhantomData<fn() -> (T, S, B)>,
}
#[test]
fn traits() {
use crate::test_helpers::*;
assert_send::<IntoServiceStateInExtension<(), NotSendSync, (), NotSendSync>>();
assert_sync::<IntoServiceStateInExtension<(), NotSendSync, (), NotSendSync>>();
}
impl<H, T, S, B> IntoServiceStateInExtension<H, T, S, B> {
pub(crate) fn new(handler: H) -> Self {
Self {
handler,
_marker: PhantomData,
}
}
}
impl<H, T, S, B> fmt::Debug for IntoServiceStateInExtension<H, T, S, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("IntoServiceStateInExtension")
.finish_non_exhaustive()
}
}
impl<H, T, S, B> Clone for IntoServiceStateInExtension<H, T, S, B>
where
H: Clone,
{
fn clone(&self) -> Self {
Self {
handler: self.handler.clone(),
_marker: PhantomData,
}
}
}
impl<H, T, S, B> Service<Request<B>> for IntoServiceStateInExtension<H, T, S, B>
where
H: Handler<T, S, B> + Clone + Send + 'static,
B: Send + 'static,
S: Send + Sync + 'static,
{
type Response = Response;
type Error = Infallible;
type Future = super::future::IntoServiceFuture<H::Future>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// `IntoServiceStateInExtension` can only be constructed from async functions which are always ready, or
// from `Layered` which bufferes in `<Layered as Handler>::call` and is therefore
// also always ready.
Poll::Ready(Ok(()))
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
use futures_util::future::FutureExt;
let state = req
.extensions_mut()
.remove::<Arc<S>>()
.expect("state extension missing. This is a bug in axum, please file an issue");
let handler = self.handler.clone();
let future = Handler::call(handler, req, state);
let future = future.map(Ok as _);
super::future::IntoServiceFuture::new(future)
}
}

View file

@ -51,13 +51,10 @@ use tower_service::Service;
mod boxed;
pub mod future;
mod into_service_state_in_extension;
mod service;
pub(crate) use self::boxed::BoxedHandler;
pub use self::service::HandlerService;
pub(crate) use self::{
boxed::BoxedHandler, into_service_state_in_extension::IntoServiceStateInExtension,
};
/// Trait for async functions that can be used to handle requests.
///

View file

@ -6,10 +6,11 @@ use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
use crate::{
body::{Body, Bytes, HttpBody},
error_handling::{HandleError, HandleErrorLayer},
handler::{Handler, IntoServiceStateInExtension},
handler::{BoxedHandler, Handler},
http::{Method, Request, StatusCode},
response::Response,
routing::{future::RouteFuture, Fallback, MethodFilter, Route},
util::try_downcast,
};
use axum_core::response::IntoResponse;
use bytes::BytesMut;
@ -511,19 +512,19 @@ where
/// {}
/// ```
pub struct MethodRouter<S = (), B = Body, E = Infallible> {
get: Option<Route<B, E>>,
head: Option<Route<B, E>>,
delete: Option<Route<B, E>>,
options: Option<Route<B, E>>,
patch: Option<Route<B, E>>,
post: Option<Route<B, E>>,
put: Option<Route<B, E>>,
trace: Option<Route<B, E>>,
get: MethodEndpoint<S, B, E>,
head: MethodEndpoint<S, B, E>,
delete: MethodEndpoint<S, B, E>,
options: MethodEndpoint<S, B, E>,
patch: MethodEndpoint<S, B, E>,
post: MethodEndpoint<S, B, E>,
put: MethodEndpoint<S, B, E>,
trace: MethodEndpoint<S, B, E>,
fallback: Fallback<S, B, E>,
allow_header: AllowHeader,
}
#[derive(Clone)]
#[derive(Clone, Debug)]
enum AllowHeader {
/// No `Allow` header value has been built-up yet. This is the default state
None,
@ -561,6 +562,7 @@ impl<S, B, E> fmt::Debug for MethodRouter<S, B, E> {
.field("put", &self.put)
.field("trace", &self.trace)
.field("fallback", &self.fallback)
.field("allow_header", &self.allow_header)
.finish()
}
}
@ -599,7 +601,10 @@ where
T: 'static,
S: Send + Sync + 'static,
{
self.on_service(filter, IntoServiceStateInExtension::new(handler))
self.on_endpoint(
filter,
MethodEndpoint::BoxedHandler(BoxedHandler::new(handler)),
)
}
chained_handler_fn!(delete, DELETE);
@ -612,13 +617,14 @@ where
chained_handler_fn!(trace, TRACE);
/// Add a fallback [`Handler`] to the router.
pub fn fallback<H, T>(self, handler: H) -> Self
pub fn fallback<H, T>(mut self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: 'static,
S: Send + Sync + 'static,
{
self.fallback_service(IntoServiceStateInExtension::new(handler))
self.fallback = Fallback::BoxedHandler(BoxedHandler::new(handler));
self
}
}
@ -708,14 +714,14 @@ where
}));
Self {
get: None,
head: None,
delete: None,
options: None,
patch: None,
post: None,
put: None,
trace: None,
get: MethodEndpoint::None,
head: MethodEndpoint::None,
delete: MethodEndpoint::None,
options: MethodEndpoint::None,
patch: MethodEndpoint::None,
post: MethodEndpoint::None,
put: MethodEndpoint::None,
trace: MethodEndpoint::None,
allow_header: AllowHeader::None,
fallback: Fallback::Default(fallback),
}
@ -724,17 +730,25 @@ where
/// Provide the state.
///
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn with_state(self, state: S) -> WithState<S, B, E> {
pub fn with_state(self, state: S) -> WithState<B, E> {
self.with_state_arc(Arc::new(state))
}
/// Provide the [`Arc`]'ed state.
///
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn with_state_arc(self, state: Arc<S>) -> WithState<S, B, E> {
pub fn with_state_arc(self, state: Arc<S>) -> WithState<B, E> {
WithState {
method_router: self,
state,
get: self.get.into_route(&state),
head: self.head.into_route(&state),
delete: self.delete.into_route(&state),
options: self.options.into_route(&state),
patch: self.patch.into_route(&state),
post: self.post.into_route(&state),
put: self.put.into_route(&state),
trace: self.trace.into_route(&state),
fallback: self.fallback.into_route(&state),
allow_header: self.allow_header,
}
}
@ -745,14 +759,14 @@ where
S2: 'static,
{
MethodRouter {
get: self.get,
head: self.head,
delete: self.delete,
options: self.options,
patch: self.patch,
post: self.post,
put: self.put,
trace: self.trace,
get: self.get.map_state(state),
head: self.head.map_state(state),
delete: self.delete.map_state(state),
options: self.options.map_state(state),
patch: self.patch.map_state(state),
post: self.post.map_state(state),
put: self.put.map_state(state),
trace: self.trace.map_state(state),
fallback: self.fallback.map_state(state),
allow_header: self.allow_header,
}
@ -765,14 +779,14 @@ where
S2: 'static,
{
Some(MethodRouter {
get: self.get,
head: self.head,
delete: self.delete,
options: self.options,
patch: self.patch,
post: self.post,
put: self.put,
trace: self.trace,
get: self.get.downcast_state()?,
head: self.head.downcast_state()?,
delete: self.delete.downcast_state()?,
options: self.options.downcast_state()?,
patch: self.patch.downcast_state()?,
post: self.post.downcast_state()?,
put: self.put.downcast_state()?,
trace: self.trace.downcast_state()?,
fallback: self.fallback.downcast_state()?,
allow_header: self.allow_header,
})
@ -804,112 +818,118 @@ where
/// # };
/// ```
#[track_caller]
pub fn on_service<T>(mut self, filter: MethodFilter, svc: T) -> Self
pub fn on_service<T>(self, filter: MethodFilter, svc: T) -> Self
where
T: Service<Request<B>, Error = E> + Clone + Send + 'static,
T::Response: IntoResponse + 'static,
T::Future: Send + 'static,
{
// written using an inner function to generate less IR
self.on_endpoint(filter, MethodEndpoint::Route(Route::new(svc)))
}
#[track_caller]
fn set_service<T>(
fn on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint<S, B, E>) -> Self {
// written as a separate function to generate less IR
#[track_caller]
fn set_endpoint<S, B, E>(
method_name: &str,
out: &mut Option<T>,
svc: &T,
svc_filter: MethodFilter,
out: &mut MethodEndpoint<S, B, E>,
endpoint: &MethodEndpoint<S, B, E>,
endpoint_filter: MethodFilter,
filter: MethodFilter,
allow_header: &mut AllowHeader,
methods: &[&'static str],
) where
T: Clone,
MethodEndpoint<S, B, E>: Clone,
{
if svc_filter.contains(filter) {
if endpoint_filter.contains(filter) {
if out.is_some() {
panic!("Overlapping method route. Cannot add two method routes that both handle `{}`", method_name)
panic!(
"Overlapping method route. Cannot add two method routes that both handle \
`{method_name}`",
)
}
*out = Some(svc.clone());
*out = endpoint.clone();
for method in methods {
append_allow_header(allow_header, method);
}
}
}
let svc = Route::new(svc);
set_service(
set_endpoint(
"GET",
&mut self.get,
&svc,
&endpoint,
filter,
MethodFilter::GET,
&mut self.allow_header,
&["GET", "HEAD"],
);
set_service(
set_endpoint(
"HEAD",
&mut self.head,
&svc,
&endpoint,
filter,
MethodFilter::HEAD,
&mut self.allow_header,
&["HEAD"],
);
set_service(
set_endpoint(
"TRACE",
&mut self.trace,
&svc,
&endpoint,
filter,
MethodFilter::TRACE,
&mut self.allow_header,
&["TRACE"],
);
set_service(
set_endpoint(
"PUT",
&mut self.put,
&svc,
&endpoint,
filter,
MethodFilter::PUT,
&mut self.allow_header,
&["PUT"],
);
set_service(
set_endpoint(
"POST",
&mut self.post,
&svc,
&endpoint,
filter,
MethodFilter::POST,
&mut self.allow_header,
&["POST"],
);
set_service(
set_endpoint(
"PATCH",
&mut self.patch,
&svc,
&endpoint,
filter,
MethodFilter::PATCH,
&mut self.allow_header,
&["PATCH"],
);
set_service(
set_endpoint(
"OPTIONS",
&mut self.options,
&svc,
&endpoint,
filter,
MethodFilter::OPTIONS,
&mut self.allow_header,
&["OPTIONS"],
);
set_service(
set_endpoint(
"DELETE",
&mut self.delete,
&svc,
&endpoint,
filter,
MethodFilter::DELETE,
&mut self.allow_header,
@ -976,10 +996,12 @@ where
#[track_caller]
pub fn route_layer<L>(mut self, layer: L) -> MethodRouter<S, B, E>
where
L: Layer<Route<B, E>>,
L: Layer<Route<B, E>> + Clone + Send + 'static,
L::Service: Service<Request<B>, Error = E> + Clone + Send + 'static,
<L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<B>>>::Future: Send + 'static,
E: 'static,
S: 'static,
{
if self.get.is_none()
&& self.head.is_none()
@ -996,19 +1018,19 @@ where
);
}
let layer_fn = |svc| {
let layer_fn = move |svc| {
let svc = layer.layer(svc);
let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc);
Route::new(svc)
};
self.get = self.get.map(layer_fn);
self.head = self.head.map(layer_fn);
self.delete = self.delete.map(layer_fn);
self.options = self.options.map(layer_fn);
self.patch = self.patch.map(layer_fn);
self.post = self.post.map(layer_fn);
self.put = self.put.map(layer_fn);
self.get = self.get.map(layer_fn.clone());
self.head = self.head.map(layer_fn.clone());
self.delete = self.delete.map(layer_fn.clone());
self.options = self.options.map(layer_fn.clone());
self.patch = self.patch.map(layer_fn.clone());
self.post = self.post.map(layer_fn.clone());
self.put = self.put.map(layer_fn.clone());
self.trace = self.trace.map(layer_fn);
self
@ -1022,24 +1044,27 @@ where
) -> Self {
// written using inner functions to generate less IR
#[track_caller]
fn merge_inner<T>(
fn merge_inner<S, B, E>(
path: Option<&str>,
name: &str,
first: Option<T>,
second: Option<T>,
) -> Option<T> {
first: MethodEndpoint<S, B, E>,
second: MethodEndpoint<S, B, E>,
) -> MethodEndpoint<S, B, E> {
match (first, second) {
(Some(_), Some(_)) => {
(MethodEndpoint::None, MethodEndpoint::None) => MethodEndpoint::None,
(pick, MethodEndpoint::None) | (MethodEndpoint::None, pick) => pick,
_ => {
if let Some(path) = path {
panic!(
"Overlapping method route. Handler for `{name} {path}` already exists"
)
);
} else {
panic!("Overlapping method route. Cannot merge two method routes that both define `{name}`")
panic!(
"Overlapping method route. Cannot merge two method routes that both \
define `{name}`"
);
}
}
(Some(svc), None) | (None, Some(svc)) => Some(svc),
(None, None) => None,
}
}
@ -1155,6 +1180,92 @@ where
}
}
enum MethodEndpoint<S, B, E> {
None,
Route(Route<B, E>),
BoxedHandler(BoxedHandler<S, B, E>),
}
impl<S, B, E> MethodEndpoint<S, B, E> {
fn is_some(&self) -> bool {
matches!(self, Self::Route(_) | Self::BoxedHandler(_))
}
fn is_none(&self) -> bool {
matches!(self, Self::None)
}
fn map<F, B2, E2>(self, f: F) -> MethodEndpoint<S, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
B2: 'static,
E2: 'static,
{
match self {
Self::None => MethodEndpoint::None,
Self::Route(route) => MethodEndpoint::Route(f(route)),
Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)),
}
}
fn map_state<S2>(self, state: &Arc<S>) -> MethodEndpoint<S2, B, E> {
match self {
Self::None => MethodEndpoint::None,
Self::Route(route) => MethodEndpoint::Route(route),
Self::BoxedHandler(handler) => MethodEndpoint::Route(handler.into_route(state.clone())),
}
}
fn downcast_state<S2>(self) -> Option<MethodEndpoint<S2, B, E>>
where
S: 'static,
B: 'static,
E: 'static,
S2: 'static,
{
match self {
Self::None => Some(MethodEndpoint::None),
Self::Route(route) => Some(MethodEndpoint::Route(route)),
Self::BoxedHandler(handler) => {
try_downcast::<BoxedHandler<S2, B, E>, BoxedHandler<S, B, E>>(handler)
.map(MethodEndpoint::BoxedHandler)
.ok()
}
}
}
fn into_route(self, state: &Arc<S>) -> Option<Route<B, E>> {
match self {
Self::None => None,
Self::Route(route) => Some(route),
Self::BoxedHandler(handler) => Some(handler.into_route(state.clone())),
}
}
}
impl<S, B, E> Clone for MethodEndpoint<S, B, E> {
fn clone(&self) -> Self {
match self {
Self::None => Self::None,
Self::Route(inner) => Self::Route(inner.clone()),
Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
}
}
}
impl<S, B, E> fmt::Debug for MethodEndpoint<S, B, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::None => f.debug_tuple("None").finish(),
Self::Route(inner) => inner.fmt(f),
Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
}
}
}
/// A [`MethodRouter`] which has access to some state.
///
/// Implements [`Service`].
@ -1162,17 +1273,20 @@ where
/// The state can be extracted with [`State`](crate::extract::State).
///
/// Created with [`MethodRouter::with_state`]
pub struct WithState<S, B, E> {
method_router: MethodRouter<S, B, E>,
state: Arc<S>,
}
impl<S, B, E> WithState<S, B, E> {
/// Get a reference to the state.
pub fn state(&self) -> &S {
&self.state
pub struct WithState<B, E> {
get: Option<Route<B, E>>,
head: Option<Route<B, E>>,
delete: Option<Route<B, E>>,
options: Option<Route<B, E>>,
patch: Option<Route<B, E>>,
post: Option<Route<B, E>>,
put: Option<Route<B, E>>,
trace: Option<Route<B, E>>,
fallback: Route<B, E>,
allow_header: AllowHeader,
}
impl<B, E> WithState<B, E> {
/// Convert the handler into a [`MakeService`].
///
/// See [`MethodRouter::into_make_service`] for more details.
@ -1194,31 +1308,43 @@ impl<S, B, E> WithState<S, B, E> {
}
}
impl<S, B, E> Clone for WithState<S, B, E> {
impl<B, E> Clone for WithState<B, E> {
fn clone(&self) -> Self {
Self {
method_router: self.method_router.clone(),
state: Arc::clone(&self.state),
get: self.get.clone(),
head: self.head.clone(),
delete: self.delete.clone(),
options: self.options.clone(),
patch: self.patch.clone(),
post: self.post.clone(),
put: self.put.clone(),
trace: self.trace.clone(),
fallback: self.fallback.clone(),
allow_header: self.allow_header.clone(),
}
}
}
impl<S, B, E> fmt::Debug for WithState<S, B, E>
where
S: fmt::Debug,
{
impl<B, E> fmt::Debug for WithState<B, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("WithState")
.field("method_router", &self.method_router)
.field("state", &self.state)
.field("get", &self.get)
.field("head", &self.head)
.field("delete", &self.delete)
.field("options", &self.options)
.field("patch", &self.patch)
.field("post", &self.post)
.field("put", &self.put)
.field("trace", &self.trace)
.field("fallback", &self.fallback)
.field("allow_header", &self.allow_header)
.finish()
}
}
impl<S, B, E> Service<Request<B>> for WithState<S, B, E>
impl<B, E> Service<Request<B>> for WithState<B, E>
where
B: HttpBody + Send,
S: Send + Sync + 'static,
{
type Response = Response;
type Error = E;
@ -1229,7 +1355,7 @@ where
Poll::Ready(Ok(()))
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
fn call(&mut self, req: Request<B>) -> Self::Future {
macro_rules! call {
(
$req:expr,
@ -1250,9 +1376,6 @@ where
// written with a pattern match like this to ensure we call all routes
let Self {
state,
method_router:
MethodRouter {
get,
head,
delete,
@ -1263,11 +1386,8 @@ where
trace,
fallback,
allow_header,
},
} = self;
req.extensions_mut().insert(Arc::clone(state));
call!(req, method, HEAD, head);
call!(req, method, HEAD, get);
call!(req, method, GET, get);
@ -1278,18 +1398,7 @@ where
call!(req, method, DELETE, delete);
call!(req, method, TRACE, trace);
let future = match fallback {
Fallback::Default(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req))
.strip_body(method == Method::HEAD),
Fallback::Service(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req))
.strip_body(method == Method::HEAD),
Fallback::BoxedHandler(fallback) => RouteFuture::from_future(
fallback
.clone()
.into_route(Arc::clone(state))
.oneshot_inner(req),
),
};
let future = RouteFuture::from_future(fallback.oneshot_inner(req));
match allow_header {
AllowHeader::None => future.allow_header(Bytes::new()),
@ -1302,7 +1411,10 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::{body::Body, error_handling::HandleErrorLayer, extract::State};
use crate::{
body::Body, error_handling::HandleErrorLayer, extract::State,
handler::HandlerWithoutStateExt,
};
use axum_core::response::IntoResponse;
use http::{header::ALLOW, HeaderMap};
use std::time::Duration;
@ -1517,8 +1629,7 @@ mod tests {
expected = "Overlapping method route. Cannot add two method routes that both handle `POST`"
)]
async fn service_overlaps() {
let _: MethodRouter<()> = post_service(IntoServiceStateInExtension::<_, _, (), _>::new(ok))
.post_service(IntoServiceStateInExtension::<_, _, (), _>::new(ok));
let _: MethodRouter<()> = post_service(ok.into_service()).post_service(ok.into_service());
}
#[tokio::test]

View file

@ -7,7 +7,6 @@ use crate::{
body::{Body, HttpBody},
handler::{BoxedHandler, Handler},
util::try_downcast,
Extension,
};
use axum_core::response::IntoResponse;
use http::Request;
@ -350,9 +349,7 @@ where
// other has its state set
Some(state) => {
let fallback = fallback.map_state(&state);
cast_method_router_closure_slot = move |r: MethodRouter<_, _>| {
r.layer(Extension(Arc::clone(&state))).map_state(&state)
};
cast_method_router_closure_slot = move |r: MethodRouter<_, _>| r.map_state(&state);
let cast_method_router = &cast_method_router_closure_slot
as &dyn Fn(MethodRouter<_, _>) -> MethodRouter<_, _>;
@ -637,6 +634,14 @@ impl<S, B, E> Fallback<S, B, E> {
_ => None,
}
}
fn into_route(self, state: &Arc<S>) -> Route<B, E> {
match self {
Self::Default(route) => route,
Self::Service(route) => route,
Self::BoxedHandler(handler) => handler.into_route(state.clone()),
}
}
}
impl<S, B, E> Clone for Fallback<S, B, E> {
@ -677,6 +682,7 @@ impl<S, B, E> Fallback<S, B, E> {
}
}
#[allow(clippy::large_enum_variant)] // This type is only used at init time, probably fine
enum Endpoint<S, B> {
MethodRouter(MethodRouter<S, B>),
Route(Route<B>),
@ -691,10 +697,7 @@ impl<S, B> Clone for Endpoint<S, B> {
}
}
impl<S, B> fmt::Debug for Endpoint<S, B>
where
S: fmt::Debug,
{
impl<S, B> fmt::Debug for Endpoint<S, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::MethodRouter(inner) => inner.fmt(f),

View file

@ -10,7 +10,7 @@ use matchit::MatchError;
use tower::Service;
use super::{
future::RouteFuture, url_params, Endpoint, Fallback, Node, Route, RouteId, Router,
future::RouteFuture, url_params, Endpoint, Node, Route, RouteId, Router,
NEST_TAIL_PARAM_CAPTURE,
};
use crate::{
@ -57,11 +57,7 @@ where
Self {
routes,
node: router.node,
fallback: match router.fallback {
Fallback::Default(route) => route,
Fallback::Service(route) => route,
Fallback::BoxedHandler(handler) => handler.into_route(state),
},
fallback: router.fallback.into_route(&state),
}
}