Improve build times by generating less IR (#1192)

* example

* `MethodRouter::merge`

* `set_content_length` and `set_allow_header`

* `MethodRouter::on_service_boxed_response_body`

* `Router::route`

* `MethodRouter::merge` again

* `MethodRouter::on_service_boxed_response_body`

* `Router::call_route`

* `MethodRouter::{layer, route_layer}`

* revert example

* fix test

* move function to method on `AllowHeader`
This commit is contained in:
David Pedersen 2022-07-25 20:06:37 +02:00 committed by GitHub
parent 1ace8554ce
commit 234c8ccb13
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 231 additions and 220 deletions

View file

@ -13,10 +13,9 @@ use bytes::BytesMut;
use std::{ use std::{
convert::Infallible, convert::Infallible,
fmt, fmt,
marker::PhantomData,
task::{Context, Poll}, task::{Context, Poll},
}; };
use tower::{service_fn, util::MapResponseLayer, ServiceBuilder}; use tower::{service_fn, util::MapResponseLayer};
use tower_layer::Layer; use tower_layer::Layer;
use tower_service::Service; use tower_service::Service;
@ -482,7 +481,6 @@ pub struct MethodRouter<B = Body, E = Infallible> {
trace: Option<Route<B, E>>, trace: Option<Route<B, E>>,
fallback: Fallback<B, E>, fallback: Fallback<B, E>,
allow_header: AllowHeader, allow_header: AllowHeader,
_request_body: PhantomData<fn() -> (B, E)>,
} }
#[derive(Clone)] #[derive(Clone)]
@ -495,6 +493,22 @@ enum AllowHeader {
Bytes(BytesMut), Bytes(BytesMut),
} }
impl AllowHeader {
fn merge(self, other: Self) -> Self {
match (self, other) {
(AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
(AllowHeader::None, AllowHeader::None) => AllowHeader::None,
(AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
(AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
(AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
a.extend_from_slice(b",");
a.extend_from_slice(&b);
AllowHeader::Bytes(a)
}
}
}
}
impl<B, E> fmt::Debug for MethodRouter<B, E> { impl<B, E> fmt::Debug for MethodRouter<B, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MethodRouter") f.debug_struct("MethodRouter")
@ -532,7 +546,6 @@ impl<B, E> MethodRouter<B, E> {
trace: None, trace: None,
allow_header: AllowHeader::None, allow_header: AllowHeader::None,
fallback: Fallback::Default(fallback), fallback: Fallback::Default(fallback),
_request_body: PhantomData,
} }
} }
} }
@ -723,12 +736,11 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static, <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
{ {
let layer = ServiceBuilder::new() let layer_fn = |svc| {
.layer_fn(Route::new) let svc = layer.layer(svc);
.layer(MapResponseLayer::new(IntoResponse::into_response)) let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc);
.layer(layer) Route::new(svc)
.into_inner(); };
let layer_fn = |s| layer.layer(s);
MethodRouter { MethodRouter {
get: self.get.map(layer_fn), get: self.get.map(layer_fn),
@ -741,128 +753,77 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
trace: self.trace.map(layer_fn), trace: self.trace.map(layer_fn),
fallback: self.fallback.map(layer_fn), fallback: self.fallback.map(layer_fn),
allow_header: self.allow_header, allow_header: self.allow_header,
_request_body: PhantomData,
} }
} }
#[doc = include_str!("../docs/method_routing/route_layer.md")] #[doc = include_str!("../docs/method_routing/route_layer.md")]
pub fn route_layer<L>(self, layer: L) -> MethodRouter<ReqBody, E> pub fn route_layer<L>(mut self, layer: L) -> MethodRouter<ReqBody, E>
where where
L: Layer<Route<ReqBody, E>>, L: Layer<Route<ReqBody, E>>,
L::Service: Service<Request<ReqBody>, Error = E> + Clone + Send + 'static, L::Service: Service<Request<ReqBody>, Error = E> + Clone + Send + 'static,
<L::Service as Service<Request<ReqBody>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<ReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<ReqBody>>>::Future: Send + 'static, <L::Service as Service<Request<ReqBody>>>::Future: Send + 'static,
{ {
let layer = ServiceBuilder::new() let layer_fn = |svc| {
.layer_fn(Route::new) let svc = layer.layer(svc);
.layer(MapResponseLayer::new(IntoResponse::into_response)) let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc);
.layer(layer) Route::new(svc)
.into_inner(); };
let layer_fn = |s| layer.layer(s);
MethodRouter { self.get = self.get.map(layer_fn);
get: self.get.map(layer_fn), self.head = self.head.map(layer_fn);
head: self.head.map(layer_fn), self.delete = self.delete.map(layer_fn);
delete: self.delete.map(layer_fn), self.options = self.options.map(layer_fn);
options: self.options.map(layer_fn), self.patch = self.patch.map(layer_fn);
patch: self.patch.map(layer_fn), self.post = self.post.map(layer_fn);
post: self.post.map(layer_fn), self.put = self.put.map(layer_fn);
put: self.put.map(layer_fn), self.trace = self.trace.map(layer_fn);
trace: self.trace.map(layer_fn),
fallback: self.fallback, self
allow_header: self.allow_header,
_request_body: PhantomData,
}
} }
#[doc = include_str!("../docs/method_routing/merge.md")] #[doc = include_str!("../docs/method_routing/merge.md")]
pub fn merge(self, other: MethodRouter<ReqBody, E>) -> Self { pub fn merge(mut self, other: MethodRouter<ReqBody, E>) -> Self {
macro_rules! merge { // written using inner functions to generate less IR
( $first:ident, $second:ident ) => { fn merge_inner<T>(name: &str, first: Option<T>, second: Option<T>) -> Option<T> {
match ($first, $second) { match (first, second) {
(Some(_), Some(_)) => panic!(concat!( (Some(_), Some(_)) => panic!(
"Overlapping method route. Cannot merge two method routes that both define `", "Overlapping method route. Cannot merge two method routes that both define `{}`", name
stringify!($first), ),
"`" (Some(svc), None) => Some(svc),
)), (None, Some(svc)) => Some(svc),
(Some(svc), None) => Some(svc), (None, None) => None,
(None, Some(svc)) => Some(svc), }
(None, None) => None, }
fn merge_fallback<B, E>(
fallback: Fallback<B, E>,
fallback_other: Fallback<B, E>,
) -> Fallback<B, E> {
match (fallback, fallback_other) {
(pick @ Fallback::Default(_), Fallback::Default(_)) => pick,
(Fallback::Default(_), pick @ Fallback::Custom(_)) => pick,
(pick @ Fallback::Custom(_), Fallback::Default(_)) => pick,
(Fallback::Custom(_), Fallback::Custom(_)) => {
panic!("Cannot merge two `MethodRouter`s that both have a fallback")
} }
}; }
} }
let Self { self.get = merge_inner("get", self.get, other.get);
get, self.head = merge_inner("head", self.head, other.head);
head, self.delete = merge_inner("delete", self.delete, other.delete);
delete, self.options = merge_inner("options", self.options, other.options);
options, self.patch = merge_inner("patch", self.patch, other.patch);
patch, self.post = merge_inner("post", self.post, other.post);
post, self.put = merge_inner("put", self.put, other.put);
put, self.trace = merge_inner("trace", self.trace, other.trace);
trace,
fallback,
allow_header,
_request_body: _,
} = self;
let Self { self.fallback = merge_fallback(self.fallback, other.fallback);
get: get_other,
head: head_other,
delete: delete_other,
options: options_other,
patch: patch_other,
post: post_other,
put: put_other,
trace: trace_other,
fallback: fallback_other,
allow_header: allow_header_other,
_request_body: _,
} = other;
let get = merge!(get, get_other); self.allow_header = self.allow_header.merge(other.allow_header);
let head = merge!(head, head_other);
let delete = merge!(delete, delete_other);
let options = merge!(options, options_other);
let patch = merge!(patch, patch_other);
let post = merge!(post, post_other);
let put = merge!(put, put_other);
let trace = merge!(trace, trace_other);
let fallback = match (fallback, fallback_other) { self
(pick @ Fallback::Default(_), Fallback::Default(_)) => pick,
(Fallback::Default(_), pick @ Fallback::Custom(_)) => pick,
(pick @ Fallback::Custom(_), Fallback::Default(_)) => pick,
(Fallback::Custom(_), Fallback::Custom(_)) => {
panic!("Cannot merge two `MethodRouter`s that both have a fallback")
}
};
let allow_header = match (allow_header, allow_header_other) {
(AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
(AllowHeader::None, AllowHeader::None) => AllowHeader::None,
(AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
(AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
(AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
a.extend_from_slice(b",");
a.extend_from_slice(&b);
AllowHeader::Bytes(a)
}
};
Self {
get,
head,
delete,
options,
patch,
post,
put,
trace,
fallback,
allow_header,
_request_body: PhantomData,
}
} }
/// Apply a [`HandleErrorLayer`]. /// Apply a [`HandleErrorLayer`].
@ -882,81 +843,118 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
self.layer(HandleErrorLayer::new(f)) self.layer(HandleErrorLayer::new(f))
} }
fn on_service_boxed_response_body<S>(self, filter: MethodFilter, svc: S) -> Self fn on_service_boxed_response_body<S>(mut self, filter: MethodFilter, svc: S) -> Self
where where
S: Service<Request<ReqBody>, Error = E> + Clone + Send + 'static, S: Service<Request<ReqBody>, Error = E> + Clone + Send + 'static,
S::Response: IntoResponse + 'static, S::Response: IntoResponse + 'static,
S::Future: Send + 'static, S::Future: Send + 'static,
{ {
macro_rules! set_service { // written using an inner function to generate less IR
( fn set_service<T>(
$filter:ident, method_name: &str,
$svc:ident, out: &mut Option<T>,
$allow_header:ident, svc: &T,
[ svc_filter: MethodFilter,
$( filter: MethodFilter,
($out:ident, $variant:ident, [$($method:literal),+]) allow_header: &mut AllowHeader,
),+ methods: &[&'static str],
$(,)? ) where
] T: Clone,
) => { {
$( if svc_filter.contains(filter) {
if $filter.contains(MethodFilter::$variant) { if out.is_some() {
if $out.is_some() { panic!("Overlapping method route. Cannot add two method routes that both handle `{}`", method_name)
panic!("Overlapping method route. Cannot add two method routes that both handle `{}`", stringify!($variant)) }
} *out = Some(svc.clone());
$out = $svc.clone(); for method in methods {
$( append_allow_header(allow_header, method);
append_allow_header(&mut $allow_header, $method); }
)+
}
)+
} }
} }
// written with a pattern match like this to ensure we update all fields let svc = Route::new(svc);
let Self {
mut get, set_service(
mut head, "GET",
mut delete, &mut self.get,
mut options, &svc,
mut patch,
mut post,
mut put,
mut trace,
fallback,
mut allow_header,
_request_body: _,
} = self;
let svc = Some(Route::new(svc));
set_service!(
filter, filter,
svc, MethodFilter::GET,
allow_header, &mut self.allow_header,
[ &["GET", "HEAD"],
(get, GET, ["GET", "HEAD"]),
(head, HEAD, ["HEAD"]),
(delete, DELETE, ["DELETE"]),
(options, OPTIONS, ["OPTIONS"]),
(patch, PATCH, ["PATCH"]),
(post, POST, ["POST"]),
(put, PUT, ["PUT"]),
(trace, TRACE, ["TRACE"]),
]
); );
Self {
get, set_service(
head, "HEAD",
delete, &mut self.head,
options, &svc,
patch, filter,
post, MethodFilter::HEAD,
put, &mut self.allow_header,
trace, &["HEAD"],
fallback, );
allow_header,
_request_body: PhantomData, set_service(
} "TRACE",
&mut self.trace,
&svc,
filter,
MethodFilter::TRACE,
&mut self.allow_header,
&["TRACE"],
);
set_service(
"PUT",
&mut self.put,
&svc,
filter,
MethodFilter::PUT,
&mut self.allow_header,
&["PUT"],
);
set_service(
"POST",
&mut self.post,
&svc,
filter,
MethodFilter::POST,
&mut self.allow_header,
&["POST"],
);
set_service(
"PATCH",
&mut self.patch,
&svc,
filter,
MethodFilter::PATCH,
&mut self.allow_header,
&["PATCH"],
);
set_service(
"OPTIONS",
&mut self.options,
&svc,
filter,
MethodFilter::OPTIONS,
&mut self.allow_header,
&["OPTIONS"],
);
set_service(
"DELETE",
&mut self.delete,
&svc,
filter,
MethodFilter::DELETE,
&mut self.allow_header,
&["DELETE"],
);
self
} }
fn skip_allow_header(mut self) -> Self { fn skip_allow_header(mut self) -> Self {
@ -998,7 +996,6 @@ impl<B, E> Clone for MethodRouter<B, E> {
trace: self.trace.clone(), trace: self.trace.clone(),
fallback: self.fallback.clone(), fallback: self.fallback.clone(),
allow_header: self.allow_header.clone(), allow_header: self.allow_header.clone(),
_request_body: PhantomData,
} }
} }
} }
@ -1056,7 +1053,6 @@ where
trace, trace,
fallback, fallback,
allow_header, allow_header,
_request_body: _,
} = self; } = self;
call!(req, method, HEAD, head); call!(req, method, HEAD, head);
@ -1091,7 +1087,7 @@ mod tests {
use axum_core::response::IntoResponse; use axum_core::response::IntoResponse;
use http::{header::ALLOW, HeaderMap}; use http::{header::ALLOW, HeaderMap};
use std::time::Duration; use std::time::Duration;
use tower::{timeout::TimeoutLayer, Service, ServiceExt}; use tower::{timeout::TimeoutLayer, Service, ServiceBuilder, ServiceExt};
use tower_http::{auth::RequireAuthorizationLayer, services::fs::ServeDir}; use tower_http::{auth::RequireAuthorizationLayer, services::fs::ServeDir};
#[tokio::test] #[tokio::test]

View file

@ -126,12 +126,16 @@ where
T::Response: IntoResponse, T::Response: IntoResponse,
T::Future: Send + 'static, T::Future: Send + 'static,
{ {
if path.is_empty() { fn validate_path(path: &str) {
panic!("Paths must start with a `/`. Use \"/\" for root routes"); if path.is_empty() {
} else if !path.starts_with('/') { panic!("Paths must start with a `/`. Use \"/\" for root routes");
panic!("Paths must start with a `/`"); } else if !path.starts_with('/') {
panic!("Paths must start with a `/`");
}
} }
validate_path(path);
let service = match try_downcast::<Router<B>, _>(service) { let service = match try_downcast::<Router<B>, _>(service) {
Ok(_) => { Ok(_) => {
panic!("Invalid route: `Router::route` cannot be used with `Router`s. Use `Router::nest` instead") panic!("Invalid route: `Router::route` cannot be used with `Router`s. Use `Router::nest` instead")
@ -162,16 +166,20 @@ where
Err(service) => Endpoint::Route(Route::new(service)), Err(service) => Endpoint::Route(Route::new(service)),
}; };
self.set_node(path, id);
self.routes.insert(id, service);
self
}
fn set_node(&mut self, path: &str, id: RouteId) {
let mut node = let mut node =
Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone()); Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone());
if let Err(err) = node.insert(path, id) { if let Err(err) = node.insert(path, id) {
self.panic_on_matchit_error(err); self.panic_on_matchit_error(err);
} }
self.node = Arc::new(node); self.node = Arc::new(node);
self.routes.insert(id, service);
self
} }
#[doc = include_str!("../docs/routing/nest.md")] #[doc = include_str!("../docs/routing/nest.md")]
@ -419,28 +427,38 @@ where
let id = *match_.value; let id = *match_.value;
#[cfg(feature = "matched-path")] #[cfg(feature = "matched-path")]
if let Some(matched_path) = self.node.route_id_to_path.get(&id) { {
use crate::extract::MatchedPath; fn set_matched_path(
id: RouteId,
route_id_to_path: &HashMap<RouteId, Arc<str>>,
extensions: &mut http::Extensions,
) {
if let Some(matched_path) = route_id_to_path.get(&id) {
use crate::extract::MatchedPath;
let matched_path = if let Some(previous) = req.extensions_mut().get::<MatchedPath>() { let matched_path = if let Some(previous) = extensions.get::<MatchedPath>() {
// a previous `MatchedPath` might exist if we're inside a nested Router // a previous `MatchedPath` might exist if we're inside a nested Router
let previous = if let Some(previous) = let previous = if let Some(previous) =
previous.as_str().strip_suffix(NEST_TAIL_PARAM_CAPTURE) previous.as_str().strip_suffix(NEST_TAIL_PARAM_CAPTURE)
{ {
previous previous
} else {
previous.as_str()
};
let matched_path = format!("{}{}", previous, matched_path);
matched_path.into()
} else {
Arc::clone(matched_path)
};
extensions.insert(MatchedPath(matched_path));
} else { } else {
previous.as_str() #[cfg(debug_assertions)]
}; panic!("should always have a matched path for a route id");
}
}
let matched_path = format!("{}{}", previous, matched_path); set_matched_path(id, &self.node.route_id_to_path, req.extensions_mut());
matched_path.into()
} else {
Arc::clone(matched_path)
};
req.extensions_mut().insert(MatchedPath(matched_path));
} else {
#[cfg(debug_assertions)]
panic!("should always have a matched path for a route id");
} }
url_params::insert_url_params(req.extensions_mut(), match_.params); url_params::insert_url_params(req.extensions_mut(), match_.params);

View file

@ -6,7 +6,7 @@ use axum_core::response::IntoResponse;
use bytes::Bytes; use bytes::Bytes;
use http::{ use http::{
header::{self, CONTENT_LENGTH}, header::{self, CONTENT_LENGTH},
HeaderValue, Request, HeaderMap, HeaderValue, Request,
}; };
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use std::{ use std::{
@ -161,10 +161,10 @@ where
res.extensions_mut().insert(AlreadyPassedThroughRouteFuture); res.extensions_mut().insert(AlreadyPassedThroughRouteFuture);
} }
set_allow_header(&mut res, this.allow_header); set_allow_header(res.headers_mut(), this.allow_header);
// make sure to set content-length before removing the body // make sure to set content-length before removing the body
set_content_length(&mut res); set_content_length(res.size_hint(), res.headers_mut());
let res = if *this.strip_body { let res = if *this.strip_body {
res.map(|_| boxed(Empty::new())) res.map(|_| boxed(Empty::new()))
@ -176,10 +176,10 @@ where
} }
} }
fn set_allow_header<B>(res: &mut Response<B>, allow_header: &mut Option<Bytes>) { fn set_allow_header(headers: &mut HeaderMap, allow_header: &mut Option<Bytes>) {
match allow_header.take() { match allow_header.take() {
Some(allow_header) if !res.headers().contains_key(header::ALLOW) => { Some(allow_header) if !headers.contains_key(header::ALLOW) => {
res.headers_mut().insert( headers.insert(
header::ALLOW, header::ALLOW,
HeaderValue::from_maybe_shared(allow_header).expect("invalid `Allow` header"), HeaderValue::from_maybe_shared(allow_header).expect("invalid `Allow` header"),
); );
@ -188,15 +188,12 @@ fn set_allow_header<B>(res: &mut Response<B>, allow_header: &mut Option<Bytes>)
} }
} }
fn set_content_length<B>(res: &mut Response<B>) fn set_content_length(size_hint: http_body::SizeHint, headers: &mut HeaderMap) {
where if headers.contains_key(CONTENT_LENGTH) {
B: HttpBody,
{
if res.headers().contains_key(CONTENT_LENGTH) {
return; return;
} }
if let Some(size) = res.size_hint().exact() { if let Some(size) = size_hint.exact() {
let header_value = if size == 0 { let header_value = if size == 0 {
#[allow(clippy::declare_interior_mutable_const)] #[allow(clippy::declare_interior_mutable_const)]
const ZERO: HeaderValue = HeaderValue::from_static("0"); const ZERO: HeaderValue = HeaderValue::from_static("0");
@ -207,7 +204,7 @@ where
HeaderValue::from_str(buffer.format(size)).unwrap() HeaderValue::from_str(buffer.format(size)).unwrap()
}; };
res.headers_mut().insert(CONTENT_LENGTH, header_value); headers.insert(CONTENT_LENGTH, header_value);
} }
} }