Rewrite how state is passed from Router to MethodRouter

This commit is contained in:
Jonas Platte 2022-10-09 21:33:40 +02:00 committed by GitHub
parent 7cbacd1433
commit a2ab338e68
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 264 additions and 242 deletions

View file

@ -1,85 +0,0 @@
use super::Handler;
use crate::response::Response;
use http::Request;
use std::{
convert::Infallible,
fmt,
marker::PhantomData,
sync::Arc,
task::{Context, Poll},
};
use tower_service::Service;
pub(crate) struct IntoServiceStateInExtension<H, T, S, B> {
handler: H,
_marker: PhantomData<fn() -> (T, S, B)>,
}
#[test]
fn traits() {
use crate::test_helpers::*;
assert_send::<IntoServiceStateInExtension<(), NotSendSync, (), NotSendSync>>();
assert_sync::<IntoServiceStateInExtension<(), NotSendSync, (), NotSendSync>>();
}
impl<H, T, S, B> IntoServiceStateInExtension<H, T, S, B> {
pub(crate) fn new(handler: H) -> Self {
Self {
handler,
_marker: PhantomData,
}
}
}
impl<H, T, S, B> fmt::Debug for IntoServiceStateInExtension<H, T, S, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("IntoServiceStateInExtension")
.finish_non_exhaustive()
}
}
impl<H, T, S, B> Clone for IntoServiceStateInExtension<H, T, S, B>
where
H: Clone,
{
fn clone(&self) -> Self {
Self {
handler: self.handler.clone(),
_marker: PhantomData,
}
}
}
impl<H, T, S, B> Service<Request<B>> for IntoServiceStateInExtension<H, T, S, B>
where
H: Handler<T, S, B> + Clone + Send + 'static,
B: Send + 'static,
S: Send + Sync + 'static,
{
type Response = Response;
type Error = Infallible;
type Future = super::future::IntoServiceFuture<H::Future>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// `IntoServiceStateInExtension` can only be constructed from async functions which are always ready, or
// from `Layered` which bufferes in `<Layered as Handler>::call` and is therefore
// also always ready.
Poll::Ready(Ok(()))
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
use futures_util::future::FutureExt;
let state = req
.extensions_mut()
.remove::<Arc<S>>()
.expect("state extension missing. This is a bug in axum, please file an issue");
let handler = self.handler.clone();
let future = Handler::call(handler, req, state);
let future = future.map(Ok as _);
super::future::IntoServiceFuture::new(future)
}
}

View file

@ -51,13 +51,10 @@ use tower_service::Service;
mod boxed; mod boxed;
pub mod future; pub mod future;
mod into_service_state_in_extension;
mod service; mod service;
pub(crate) use self::boxed::BoxedHandler;
pub use self::service::HandlerService; pub use self::service::HandlerService;
pub(crate) use self::{
boxed::BoxedHandler, into_service_state_in_extension::IntoServiceStateInExtension,
};
/// Trait for async functions that can be used to handle requests. /// Trait for async functions that can be used to handle requests.
/// ///

View file

@ -6,10 +6,11 @@ use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
use crate::{ use crate::{
body::{Body, Bytes, HttpBody}, body::{Body, Bytes, HttpBody},
error_handling::{HandleError, HandleErrorLayer}, error_handling::{HandleError, HandleErrorLayer},
handler::{Handler, IntoServiceStateInExtension}, handler::{BoxedHandler, Handler},
http::{Method, Request, StatusCode}, http::{Method, Request, StatusCode},
response::Response, response::Response,
routing::{future::RouteFuture, Fallback, MethodFilter, Route}, routing::{future::RouteFuture, Fallback, MethodFilter, Route},
util::try_downcast,
}; };
use axum_core::response::IntoResponse; use axum_core::response::IntoResponse;
use bytes::BytesMut; use bytes::BytesMut;
@ -511,19 +512,19 @@ where
/// {} /// {}
/// ``` /// ```
pub struct MethodRouter<S = (), B = Body, E = Infallible> { pub struct MethodRouter<S = (), B = Body, E = Infallible> {
get: Option<Route<B, E>>, get: MethodEndpoint<S, B, E>,
head: Option<Route<B, E>>, head: MethodEndpoint<S, B, E>,
delete: Option<Route<B, E>>, delete: MethodEndpoint<S, B, E>,
options: Option<Route<B, E>>, options: MethodEndpoint<S, B, E>,
patch: Option<Route<B, E>>, patch: MethodEndpoint<S, B, E>,
post: Option<Route<B, E>>, post: MethodEndpoint<S, B, E>,
put: Option<Route<B, E>>, put: MethodEndpoint<S, B, E>,
trace: Option<Route<B, E>>, trace: MethodEndpoint<S, B, E>,
fallback: Fallback<S, B, E>, fallback: Fallback<S, B, E>,
allow_header: AllowHeader, allow_header: AllowHeader,
} }
#[derive(Clone)] #[derive(Clone, Debug)]
enum AllowHeader { enum AllowHeader {
/// No `Allow` header value has been built-up yet. This is the default state /// No `Allow` header value has been built-up yet. This is the default state
None, None,
@ -561,6 +562,7 @@ impl<S, B, E> fmt::Debug for MethodRouter<S, B, E> {
.field("put", &self.put) .field("put", &self.put)
.field("trace", &self.trace) .field("trace", &self.trace)
.field("fallback", &self.fallback) .field("fallback", &self.fallback)
.field("allow_header", &self.allow_header)
.finish() .finish()
} }
} }
@ -599,7 +601,10 @@ where
T: 'static, T: 'static,
S: Send + Sync + 'static, S: Send + Sync + 'static,
{ {
self.on_service(filter, IntoServiceStateInExtension::new(handler)) self.on_endpoint(
filter,
MethodEndpoint::BoxedHandler(BoxedHandler::new(handler)),
)
} }
chained_handler_fn!(delete, DELETE); chained_handler_fn!(delete, DELETE);
@ -612,13 +617,14 @@ where
chained_handler_fn!(trace, TRACE); chained_handler_fn!(trace, TRACE);
/// Add a fallback [`Handler`] to the router. /// Add a fallback [`Handler`] to the router.
pub fn fallback<H, T>(self, handler: H) -> Self pub fn fallback<H, T>(mut self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: 'static, T: 'static,
S: Send + Sync + 'static, S: Send + Sync + 'static,
{ {
self.fallback_service(IntoServiceStateInExtension::new(handler)) self.fallback = Fallback::BoxedHandler(BoxedHandler::new(handler));
self
} }
} }
@ -708,14 +714,14 @@ where
})); }));
Self { Self {
get: None, get: MethodEndpoint::None,
head: None, head: MethodEndpoint::None,
delete: None, delete: MethodEndpoint::None,
options: None, options: MethodEndpoint::None,
patch: None, patch: MethodEndpoint::None,
post: None, post: MethodEndpoint::None,
put: None, put: MethodEndpoint::None,
trace: None, trace: MethodEndpoint::None,
allow_header: AllowHeader::None, allow_header: AllowHeader::None,
fallback: Fallback::Default(fallback), fallback: Fallback::Default(fallback),
} }
@ -724,17 +730,25 @@ where
/// Provide the state. /// Provide the state.
/// ///
/// See [`State`](crate::extract::State) for more details about accessing state. /// See [`State`](crate::extract::State) for more details about accessing state.
pub fn with_state(self, state: S) -> WithState<S, B, E> { pub fn with_state(self, state: S) -> WithState<B, E> {
self.with_state_arc(Arc::new(state)) self.with_state_arc(Arc::new(state))
} }
/// Provide the [`Arc`]'ed state. /// Provide the [`Arc`]'ed state.
/// ///
/// See [`State`](crate::extract::State) for more details about accessing state. /// See [`State`](crate::extract::State) for more details about accessing state.
pub fn with_state_arc(self, state: Arc<S>) -> WithState<S, B, E> { pub fn with_state_arc(self, state: Arc<S>) -> WithState<B, E> {
WithState { WithState {
method_router: self, get: self.get.into_route(&state),
state, head: self.head.into_route(&state),
delete: self.delete.into_route(&state),
options: self.options.into_route(&state),
patch: self.patch.into_route(&state),
post: self.post.into_route(&state),
put: self.put.into_route(&state),
trace: self.trace.into_route(&state),
fallback: self.fallback.into_route(&state),
allow_header: self.allow_header,
} }
} }
@ -745,14 +759,14 @@ where
S2: 'static, S2: 'static,
{ {
MethodRouter { MethodRouter {
get: self.get, get: self.get.map_state(state),
head: self.head, head: self.head.map_state(state),
delete: self.delete, delete: self.delete.map_state(state),
options: self.options, options: self.options.map_state(state),
patch: self.patch, patch: self.patch.map_state(state),
post: self.post, post: self.post.map_state(state),
put: self.put, put: self.put.map_state(state),
trace: self.trace, trace: self.trace.map_state(state),
fallback: self.fallback.map_state(state), fallback: self.fallback.map_state(state),
allow_header: self.allow_header, allow_header: self.allow_header,
} }
@ -765,14 +779,14 @@ where
S2: 'static, S2: 'static,
{ {
Some(MethodRouter { Some(MethodRouter {
get: self.get, get: self.get.downcast_state()?,
head: self.head, head: self.head.downcast_state()?,
delete: self.delete, delete: self.delete.downcast_state()?,
options: self.options, options: self.options.downcast_state()?,
patch: self.patch, patch: self.patch.downcast_state()?,
post: self.post, post: self.post.downcast_state()?,
put: self.put, put: self.put.downcast_state()?,
trace: self.trace, trace: self.trace.downcast_state()?,
fallback: self.fallback.downcast_state()?, fallback: self.fallback.downcast_state()?,
allow_header: self.allow_header, allow_header: self.allow_header,
}) })
@ -804,112 +818,118 @@ where
/// # }; /// # };
/// ``` /// ```
#[track_caller] #[track_caller]
pub fn on_service<T>(mut self, filter: MethodFilter, svc: T) -> Self pub fn on_service<T>(self, filter: MethodFilter, svc: T) -> Self
where where
T: Service<Request<B>, Error = E> + Clone + Send + 'static, T: Service<Request<B>, Error = E> + Clone + Send + 'static,
T::Response: IntoResponse + 'static, T::Response: IntoResponse + 'static,
T::Future: Send + 'static, T::Future: Send + 'static,
{ {
// written using an inner function to generate less IR self.on_endpoint(filter, MethodEndpoint::Route(Route::new(svc)))
}
#[track_caller] #[track_caller]
fn set_service<T>( fn on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint<S, B, E>) -> Self {
// written as a separate function to generate less IR
#[track_caller]
fn set_endpoint<S, B, E>(
method_name: &str, method_name: &str,
out: &mut Option<T>, out: &mut MethodEndpoint<S, B, E>,
svc: &T, endpoint: &MethodEndpoint<S, B, E>,
svc_filter: MethodFilter, endpoint_filter: MethodFilter,
filter: MethodFilter, filter: MethodFilter,
allow_header: &mut AllowHeader, allow_header: &mut AllowHeader,
methods: &[&'static str], methods: &[&'static str],
) where ) where
T: Clone, MethodEndpoint<S, B, E>: Clone,
{ {
if svc_filter.contains(filter) { if endpoint_filter.contains(filter) {
if out.is_some() { if out.is_some() {
panic!("Overlapping method route. Cannot add two method routes that both handle `{}`", method_name) panic!(
"Overlapping method route. Cannot add two method routes that both handle \
`{method_name}`",
)
} }
*out = Some(svc.clone()); *out = endpoint.clone();
for method in methods { for method in methods {
append_allow_header(allow_header, method); append_allow_header(allow_header, method);
} }
} }
} }
let svc = Route::new(svc); set_endpoint(
set_service(
"GET", "GET",
&mut self.get, &mut self.get,
&svc, &endpoint,
filter, filter,
MethodFilter::GET, MethodFilter::GET,
&mut self.allow_header, &mut self.allow_header,
&["GET", "HEAD"], &["GET", "HEAD"],
); );
set_service( set_endpoint(
"HEAD", "HEAD",
&mut self.head, &mut self.head,
&svc, &endpoint,
filter, filter,
MethodFilter::HEAD, MethodFilter::HEAD,
&mut self.allow_header, &mut self.allow_header,
&["HEAD"], &["HEAD"],
); );
set_service( set_endpoint(
"TRACE", "TRACE",
&mut self.trace, &mut self.trace,
&svc, &endpoint,
filter, filter,
MethodFilter::TRACE, MethodFilter::TRACE,
&mut self.allow_header, &mut self.allow_header,
&["TRACE"], &["TRACE"],
); );
set_service( set_endpoint(
"PUT", "PUT",
&mut self.put, &mut self.put,
&svc, &endpoint,
filter, filter,
MethodFilter::PUT, MethodFilter::PUT,
&mut self.allow_header, &mut self.allow_header,
&["PUT"], &["PUT"],
); );
set_service( set_endpoint(
"POST", "POST",
&mut self.post, &mut self.post,
&svc, &endpoint,
filter, filter,
MethodFilter::POST, MethodFilter::POST,
&mut self.allow_header, &mut self.allow_header,
&["POST"], &["POST"],
); );
set_service( set_endpoint(
"PATCH", "PATCH",
&mut self.patch, &mut self.patch,
&svc, &endpoint,
filter, filter,
MethodFilter::PATCH, MethodFilter::PATCH,
&mut self.allow_header, &mut self.allow_header,
&["PATCH"], &["PATCH"],
); );
set_service( set_endpoint(
"OPTIONS", "OPTIONS",
&mut self.options, &mut self.options,
&svc, &endpoint,
filter, filter,
MethodFilter::OPTIONS, MethodFilter::OPTIONS,
&mut self.allow_header, &mut self.allow_header,
&["OPTIONS"], &["OPTIONS"],
); );
set_service( set_endpoint(
"DELETE", "DELETE",
&mut self.delete, &mut self.delete,
&svc, &endpoint,
filter, filter,
MethodFilter::DELETE, MethodFilter::DELETE,
&mut self.allow_header, &mut self.allow_header,
@ -976,10 +996,12 @@ where
#[track_caller] #[track_caller]
pub fn route_layer<L>(mut self, layer: L) -> MethodRouter<S, B, E> pub fn route_layer<L>(mut self, layer: L) -> MethodRouter<S, B, E>
where where
L: Layer<Route<B, E>>, L: Layer<Route<B, E>> + Clone + Send + 'static,
L::Service: Service<Request<B>, Error = E> + Clone + Send + 'static, L::Service: Service<Request<B>, Error = E> + Clone + Send + 'static,
<L::Service as Service<Request<B>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<B>>>::Future: Send + 'static, <L::Service as Service<Request<B>>>::Future: Send + 'static,
E: 'static,
S: 'static,
{ {
if self.get.is_none() if self.get.is_none()
&& self.head.is_none() && self.head.is_none()
@ -996,19 +1018,19 @@ where
); );
} }
let layer_fn = |svc| { let layer_fn = move |svc| {
let svc = layer.layer(svc); let svc = layer.layer(svc);
let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc); let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc);
Route::new(svc) Route::new(svc)
}; };
self.get = self.get.map(layer_fn); self.get = self.get.map(layer_fn.clone());
self.head = self.head.map(layer_fn); self.head = self.head.map(layer_fn.clone());
self.delete = self.delete.map(layer_fn); self.delete = self.delete.map(layer_fn.clone());
self.options = self.options.map(layer_fn); self.options = self.options.map(layer_fn.clone());
self.patch = self.patch.map(layer_fn); self.patch = self.patch.map(layer_fn.clone());
self.post = self.post.map(layer_fn); self.post = self.post.map(layer_fn.clone());
self.put = self.put.map(layer_fn); self.put = self.put.map(layer_fn.clone());
self.trace = self.trace.map(layer_fn); self.trace = self.trace.map(layer_fn);
self self
@ -1022,24 +1044,27 @@ where
) -> Self { ) -> Self {
// written using inner functions to generate less IR // written using inner functions to generate less IR
#[track_caller] #[track_caller]
fn merge_inner<T>( fn merge_inner<S, B, E>(
path: Option<&str>, path: Option<&str>,
name: &str, name: &str,
first: Option<T>, first: MethodEndpoint<S, B, E>,
second: Option<T>, second: MethodEndpoint<S, B, E>,
) -> Option<T> { ) -> MethodEndpoint<S, B, E> {
match (first, second) { match (first, second) {
(Some(_), Some(_)) => { (MethodEndpoint::None, MethodEndpoint::None) => MethodEndpoint::None,
(pick, MethodEndpoint::None) | (MethodEndpoint::None, pick) => pick,
_ => {
if let Some(path) = path { if let Some(path) = path {
panic!( panic!(
"Overlapping method route. Handler for `{name} {path}` already exists" "Overlapping method route. Handler for `{name} {path}` already exists"
) );
} else { } else {
panic!("Overlapping method route. Cannot merge two method routes that both define `{name}`") panic!(
"Overlapping method route. Cannot merge two method routes that both \
define `{name}`"
);
} }
} }
(Some(svc), None) | (None, Some(svc)) => Some(svc),
(None, None) => None,
} }
} }
@ -1155,6 +1180,92 @@ where
} }
} }
enum MethodEndpoint<S, B, E> {
None,
Route(Route<B, E>),
BoxedHandler(BoxedHandler<S, B, E>),
}
impl<S, B, E> MethodEndpoint<S, B, E> {
fn is_some(&self) -> bool {
matches!(self, Self::Route(_) | Self::BoxedHandler(_))
}
fn is_none(&self) -> bool {
matches!(self, Self::None)
}
fn map<F, B2, E2>(self, f: F) -> MethodEndpoint<S, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
B2: 'static,
E2: 'static,
{
match self {
Self::None => MethodEndpoint::None,
Self::Route(route) => MethodEndpoint::Route(f(route)),
Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)),
}
}
fn map_state<S2>(self, state: &Arc<S>) -> MethodEndpoint<S2, B, E> {
match self {
Self::None => MethodEndpoint::None,
Self::Route(route) => MethodEndpoint::Route(route),
Self::BoxedHandler(handler) => MethodEndpoint::Route(handler.into_route(state.clone())),
}
}
fn downcast_state<S2>(self) -> Option<MethodEndpoint<S2, B, E>>
where
S: 'static,
B: 'static,
E: 'static,
S2: 'static,
{
match self {
Self::None => Some(MethodEndpoint::None),
Self::Route(route) => Some(MethodEndpoint::Route(route)),
Self::BoxedHandler(handler) => {
try_downcast::<BoxedHandler<S2, B, E>, BoxedHandler<S, B, E>>(handler)
.map(MethodEndpoint::BoxedHandler)
.ok()
}
}
}
fn into_route(self, state: &Arc<S>) -> Option<Route<B, E>> {
match self {
Self::None => None,
Self::Route(route) => Some(route),
Self::BoxedHandler(handler) => Some(handler.into_route(state.clone())),
}
}
}
impl<S, B, E> Clone for MethodEndpoint<S, B, E> {
fn clone(&self) -> Self {
match self {
Self::None => Self::None,
Self::Route(inner) => Self::Route(inner.clone()),
Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
}
}
}
impl<S, B, E> fmt::Debug for MethodEndpoint<S, B, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::None => f.debug_tuple("None").finish(),
Self::Route(inner) => inner.fmt(f),
Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
}
}
}
/// A [`MethodRouter`] which has access to some state. /// A [`MethodRouter`] which has access to some state.
/// ///
/// Implements [`Service`]. /// Implements [`Service`].
@ -1162,17 +1273,20 @@ where
/// The state can be extracted with [`State`](crate::extract::State). /// The state can be extracted with [`State`](crate::extract::State).
/// ///
/// Created with [`MethodRouter::with_state`] /// Created with [`MethodRouter::with_state`]
pub struct WithState<S, B, E> { pub struct WithState<B, E> {
method_router: MethodRouter<S, B, E>, get: Option<Route<B, E>>,
state: Arc<S>, head: Option<Route<B, E>>,
} delete: Option<Route<B, E>>,
options: Option<Route<B, E>>,
impl<S, B, E> WithState<S, B, E> { patch: Option<Route<B, E>>,
/// Get a reference to the state. post: Option<Route<B, E>>,
pub fn state(&self) -> &S { put: Option<Route<B, E>>,
&self.state trace: Option<Route<B, E>>,
fallback: Route<B, E>,
allow_header: AllowHeader,
} }
impl<B, E> WithState<B, E> {
/// Convert the handler into a [`MakeService`]. /// Convert the handler into a [`MakeService`].
/// ///
/// See [`MethodRouter::into_make_service`] for more details. /// See [`MethodRouter::into_make_service`] for more details.
@ -1194,31 +1308,43 @@ impl<S, B, E> WithState<S, B, E> {
} }
} }
impl<S, B, E> Clone for WithState<S, B, E> { impl<B, E> Clone for WithState<B, E> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
method_router: self.method_router.clone(), get: self.get.clone(),
state: Arc::clone(&self.state), head: self.head.clone(),
delete: self.delete.clone(),
options: self.options.clone(),
patch: self.patch.clone(),
post: self.post.clone(),
put: self.put.clone(),
trace: self.trace.clone(),
fallback: self.fallback.clone(),
allow_header: self.allow_header.clone(),
} }
} }
} }
impl<S, B, E> fmt::Debug for WithState<S, B, E> impl<B, E> fmt::Debug for WithState<B, E> {
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("WithState") f.debug_struct("WithState")
.field("method_router", &self.method_router) .field("get", &self.get)
.field("state", &self.state) .field("head", &self.head)
.field("delete", &self.delete)
.field("options", &self.options)
.field("patch", &self.patch)
.field("post", &self.post)
.field("put", &self.put)
.field("trace", &self.trace)
.field("fallback", &self.fallback)
.field("allow_header", &self.allow_header)
.finish() .finish()
} }
} }
impl<S, B, E> Service<Request<B>> for WithState<S, B, E> impl<B, E> Service<Request<B>> for WithState<B, E>
where where
B: HttpBody + Send, B: HttpBody + Send,
S: Send + Sync + 'static,
{ {
type Response = Response; type Response = Response;
type Error = E; type Error = E;
@ -1229,7 +1355,7 @@ where
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, mut req: Request<B>) -> Self::Future { fn call(&mut self, req: Request<B>) -> Self::Future {
macro_rules! call { macro_rules! call {
( (
$req:expr, $req:expr,
@ -1250,9 +1376,6 @@ where
// written with a pattern match like this to ensure we call all routes // written with a pattern match like this to ensure we call all routes
let Self { let Self {
state,
method_router:
MethodRouter {
get, get,
head, head,
delete, delete,
@ -1263,11 +1386,8 @@ where
trace, trace,
fallback, fallback,
allow_header, allow_header,
},
} = self; } = self;
req.extensions_mut().insert(Arc::clone(state));
call!(req, method, HEAD, head); call!(req, method, HEAD, head);
call!(req, method, HEAD, get); call!(req, method, HEAD, get);
call!(req, method, GET, get); call!(req, method, GET, get);
@ -1278,18 +1398,7 @@ where
call!(req, method, DELETE, delete); call!(req, method, DELETE, delete);
call!(req, method, TRACE, trace); call!(req, method, TRACE, trace);
let future = match fallback { let future = RouteFuture::from_future(fallback.oneshot_inner(req));
Fallback::Default(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req))
.strip_body(method == Method::HEAD),
Fallback::Service(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req))
.strip_body(method == Method::HEAD),
Fallback::BoxedHandler(fallback) => RouteFuture::from_future(
fallback
.clone()
.into_route(Arc::clone(state))
.oneshot_inner(req),
),
};
match allow_header { match allow_header {
AllowHeader::None => future.allow_header(Bytes::new()), AllowHeader::None => future.allow_header(Bytes::new()),
@ -1302,7 +1411,10 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::{body::Body, error_handling::HandleErrorLayer, extract::State}; use crate::{
body::Body, error_handling::HandleErrorLayer, extract::State,
handler::HandlerWithoutStateExt,
};
use axum_core::response::IntoResponse; use axum_core::response::IntoResponse;
use http::{header::ALLOW, HeaderMap}; use http::{header::ALLOW, HeaderMap};
use std::time::Duration; use std::time::Duration;
@ -1517,8 +1629,7 @@ mod tests {
expected = "Overlapping method route. Cannot add two method routes that both handle `POST`" expected = "Overlapping method route. Cannot add two method routes that both handle `POST`"
)] )]
async fn service_overlaps() { async fn service_overlaps() {
let _: MethodRouter<()> = post_service(IntoServiceStateInExtension::<_, _, (), _>::new(ok)) let _: MethodRouter<()> = post_service(ok.into_service()).post_service(ok.into_service());
.post_service(IntoServiceStateInExtension::<_, _, (), _>::new(ok));
} }
#[tokio::test] #[tokio::test]

View file

@ -7,7 +7,6 @@ use crate::{
body::{Body, HttpBody}, body::{Body, HttpBody},
handler::{BoxedHandler, Handler}, handler::{BoxedHandler, Handler},
util::try_downcast, util::try_downcast,
Extension,
}; };
use axum_core::response::IntoResponse; use axum_core::response::IntoResponse;
use http::Request; use http::Request;
@ -350,9 +349,7 @@ where
// other has its state set // other has its state set
Some(state) => { Some(state) => {
let fallback = fallback.map_state(&state); let fallback = fallback.map_state(&state);
cast_method_router_closure_slot = move |r: MethodRouter<_, _>| { cast_method_router_closure_slot = move |r: MethodRouter<_, _>| r.map_state(&state);
r.layer(Extension(Arc::clone(&state))).map_state(&state)
};
let cast_method_router = &cast_method_router_closure_slot let cast_method_router = &cast_method_router_closure_slot
as &dyn Fn(MethodRouter<_, _>) -> MethodRouter<_, _>; as &dyn Fn(MethodRouter<_, _>) -> MethodRouter<_, _>;
@ -637,6 +634,14 @@ impl<S, B, E> Fallback<S, B, E> {
_ => None, _ => None,
} }
} }
fn into_route(self, state: &Arc<S>) -> Route<B, E> {
match self {
Self::Default(route) => route,
Self::Service(route) => route,
Self::BoxedHandler(handler) => handler.into_route(state.clone()),
}
}
} }
impl<S, B, E> Clone for Fallback<S, B, E> { impl<S, B, E> Clone for Fallback<S, B, E> {
@ -677,6 +682,7 @@ impl<S, B, E> Fallback<S, B, E> {
} }
} }
#[allow(clippy::large_enum_variant)] // This type is only used at init time, probably fine
enum Endpoint<S, B> { enum Endpoint<S, B> {
MethodRouter(MethodRouter<S, B>), MethodRouter(MethodRouter<S, B>),
Route(Route<B>), Route(Route<B>),
@ -691,10 +697,7 @@ impl<S, B> Clone for Endpoint<S, B> {
} }
} }
impl<S, B> fmt::Debug for Endpoint<S, B> impl<S, B> fmt::Debug for Endpoint<S, B> {
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
Self::MethodRouter(inner) => inner.fmt(f), Self::MethodRouter(inner) => inner.fmt(f),

View file

@ -10,7 +10,7 @@ use matchit::MatchError;
use tower::Service; use tower::Service;
use super::{ use super::{
future::RouteFuture, url_params, Endpoint, Fallback, Node, Route, RouteId, Router, future::RouteFuture, url_params, Endpoint, Node, Route, RouteId, Router,
NEST_TAIL_PARAM_CAPTURE, NEST_TAIL_PARAM_CAPTURE,
}; };
use crate::{ use crate::{
@ -57,11 +57,7 @@ where
Self { Self {
routes, routes,
node: router.node, node: router.node,
fallback: match router.fallback { fallback: router.fallback.into_route(&state),
Fallback::Default(route) => route,
Fallback::Service(route) => route,
Fallback::BoxedHandler(handler) => handler.into_route(state),
},
} }
} }