Improve build times by generating less IR (#1192)

* example

* `MethodRouter::merge`

* `set_content_length` and `set_allow_header`

* `MethodRouter::on_service_boxed_response_body`

* `Router::route`

* `MethodRouter::merge` again

* `MethodRouter::on_service_boxed_response_body`

* `Router::call_route`

* `MethodRouter::{layer, route_layer}`

* revert example

* fix test

* move function to method on `AllowHeader`
This commit is contained in:
David Pedersen 2022-07-25 20:06:37 +02:00 committed by GitHub
parent 1ace8554ce
commit 234c8ccb13
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 231 additions and 220 deletions

View file

@ -13,10 +13,9 @@ use bytes::BytesMut;
use std::{
convert::Infallible,
fmt,
marker::PhantomData,
task::{Context, Poll},
};
use tower::{service_fn, util::MapResponseLayer, ServiceBuilder};
use tower::{service_fn, util::MapResponseLayer};
use tower_layer::Layer;
use tower_service::Service;
@ -482,7 +481,6 @@ pub struct MethodRouter<B = Body, E = Infallible> {
trace: Option<Route<B, E>>,
fallback: Fallback<B, E>,
allow_header: AllowHeader,
_request_body: PhantomData<fn() -> (B, E)>,
}
#[derive(Clone)]
@ -495,6 +493,22 @@ enum AllowHeader {
Bytes(BytesMut),
}
impl AllowHeader {
fn merge(self, other: Self) -> Self {
match (self, other) {
(AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
(AllowHeader::None, AllowHeader::None) => AllowHeader::None,
(AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
(AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
(AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
a.extend_from_slice(b",");
a.extend_from_slice(&b);
AllowHeader::Bytes(a)
}
}
}
}
impl<B, E> fmt::Debug for MethodRouter<B, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MethodRouter")
@ -532,7 +546,6 @@ impl<B, E> MethodRouter<B, E> {
trace: None,
allow_header: AllowHeader::None,
fallback: Fallback::Default(fallback),
_request_body: PhantomData,
}
}
}
@ -723,12 +736,11 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
{
let layer = ServiceBuilder::new()
.layer_fn(Route::new)
.layer(MapResponseLayer::new(IntoResponse::into_response))
.layer(layer)
.into_inner();
let layer_fn = |s| layer.layer(s);
let layer_fn = |svc| {
let svc = layer.layer(svc);
let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc);
Route::new(svc)
};
MethodRouter {
get: self.get.map(layer_fn),
@ -741,128 +753,77 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
trace: self.trace.map(layer_fn),
fallback: self.fallback.map(layer_fn),
allow_header: self.allow_header,
_request_body: PhantomData,
}
}
#[doc = include_str!("../docs/method_routing/route_layer.md")]
pub fn route_layer<L>(self, layer: L) -> MethodRouter<ReqBody, E>
pub fn route_layer<L>(mut self, layer: L) -> MethodRouter<ReqBody, E>
where
L: Layer<Route<ReqBody, E>>,
L::Service: Service<Request<ReqBody>, Error = E> + Clone + Send + 'static,
<L::Service as Service<Request<ReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<ReqBody>>>::Future: Send + 'static,
{
let layer = ServiceBuilder::new()
.layer_fn(Route::new)
.layer(MapResponseLayer::new(IntoResponse::into_response))
.layer(layer)
.into_inner();
let layer_fn = |s| layer.layer(s);
let layer_fn = |svc| {
let svc = layer.layer(svc);
let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc);
Route::new(svc)
};
MethodRouter {
get: self.get.map(layer_fn),
head: self.head.map(layer_fn),
delete: self.delete.map(layer_fn),
options: self.options.map(layer_fn),
patch: self.patch.map(layer_fn),
post: self.post.map(layer_fn),
put: self.put.map(layer_fn),
trace: self.trace.map(layer_fn),
fallback: self.fallback,
allow_header: self.allow_header,
_request_body: PhantomData,
}
self.get = self.get.map(layer_fn);
self.head = self.head.map(layer_fn);
self.delete = self.delete.map(layer_fn);
self.options = self.options.map(layer_fn);
self.patch = self.patch.map(layer_fn);
self.post = self.post.map(layer_fn);
self.put = self.put.map(layer_fn);
self.trace = self.trace.map(layer_fn);
self
}
#[doc = include_str!("../docs/method_routing/merge.md")]
pub fn merge(self, other: MethodRouter<ReqBody, E>) -> Self {
macro_rules! merge {
( $first:ident, $second:ident ) => {
match ($first, $second) {
(Some(_), Some(_)) => panic!(concat!(
"Overlapping method route. Cannot merge two method routes that both define `",
stringify!($first),
"`"
)),
(Some(svc), None) => Some(svc),
(None, Some(svc)) => Some(svc),
(None, None) => None,
pub fn merge(mut self, other: MethodRouter<ReqBody, E>) -> Self {
// written using inner functions to generate less IR
fn merge_inner<T>(name: &str, first: Option<T>, second: Option<T>) -> Option<T> {
match (first, second) {
(Some(_), Some(_)) => panic!(
"Overlapping method route. Cannot merge two method routes that both define `{}`", name
),
(Some(svc), None) => Some(svc),
(None, Some(svc)) => Some(svc),
(None, None) => None,
}
}
fn merge_fallback<B, E>(
fallback: Fallback<B, E>,
fallback_other: Fallback<B, E>,
) -> Fallback<B, E> {
match (fallback, fallback_other) {
(pick @ Fallback::Default(_), Fallback::Default(_)) => pick,
(Fallback::Default(_), pick @ Fallback::Custom(_)) => pick,
(pick @ Fallback::Custom(_), Fallback::Default(_)) => pick,
(Fallback::Custom(_), Fallback::Custom(_)) => {
panic!("Cannot merge two `MethodRouter`s that both have a fallback")
}
};
}
}
let Self {
get,
head,
delete,
options,
patch,
post,
put,
trace,
fallback,
allow_header,
_request_body: _,
} = self;
self.get = merge_inner("get", self.get, other.get);
self.head = merge_inner("head", self.head, other.head);
self.delete = merge_inner("delete", self.delete, other.delete);
self.options = merge_inner("options", self.options, other.options);
self.patch = merge_inner("patch", self.patch, other.patch);
self.post = merge_inner("post", self.post, other.post);
self.put = merge_inner("put", self.put, other.put);
self.trace = merge_inner("trace", self.trace, other.trace);
let Self {
get: get_other,
head: head_other,
delete: delete_other,
options: options_other,
patch: patch_other,
post: post_other,
put: put_other,
trace: trace_other,
fallback: fallback_other,
allow_header: allow_header_other,
_request_body: _,
} = other;
self.fallback = merge_fallback(self.fallback, other.fallback);
let get = merge!(get, get_other);
let head = merge!(head, head_other);
let delete = merge!(delete, delete_other);
let options = merge!(options, options_other);
let patch = merge!(patch, patch_other);
let post = merge!(post, post_other);
let put = merge!(put, put_other);
let trace = merge!(trace, trace_other);
self.allow_header = self.allow_header.merge(other.allow_header);
let fallback = match (fallback, fallback_other) {
(pick @ Fallback::Default(_), Fallback::Default(_)) => pick,
(Fallback::Default(_), pick @ Fallback::Custom(_)) => pick,
(pick @ Fallback::Custom(_), Fallback::Default(_)) => pick,
(Fallback::Custom(_), Fallback::Custom(_)) => {
panic!("Cannot merge two `MethodRouter`s that both have a fallback")
}
};
let allow_header = match (allow_header, allow_header_other) {
(AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
(AllowHeader::None, AllowHeader::None) => AllowHeader::None,
(AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
(AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
(AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
a.extend_from_slice(b",");
a.extend_from_slice(&b);
AllowHeader::Bytes(a)
}
};
Self {
get,
head,
delete,
options,
patch,
post,
put,
trace,
fallback,
allow_header,
_request_body: PhantomData,
}
self
}
/// Apply a [`HandleErrorLayer`].
@ -882,81 +843,118 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
self.layer(HandleErrorLayer::new(f))
}
fn on_service_boxed_response_body<S>(self, filter: MethodFilter, svc: S) -> Self
fn on_service_boxed_response_body<S>(mut self, filter: MethodFilter, svc: S) -> Self
where
S: Service<Request<ReqBody>, Error = E> + Clone + Send + 'static,
S::Response: IntoResponse + 'static,
S::Future: Send + 'static,
{
macro_rules! set_service {
(
$filter:ident,
$svc:ident,
$allow_header:ident,
[
$(
($out:ident, $variant:ident, [$($method:literal),+])
),+
$(,)?
]
) => {
$(
if $filter.contains(MethodFilter::$variant) {
if $out.is_some() {
panic!("Overlapping method route. Cannot add two method routes that both handle `{}`", stringify!($variant))
}
$out = $svc.clone();
$(
append_allow_header(&mut $allow_header, $method);
)+
}
)+
// written using an inner function to generate less IR
fn set_service<T>(
method_name: &str,
out: &mut Option<T>,
svc: &T,
svc_filter: MethodFilter,
filter: MethodFilter,
allow_header: &mut AllowHeader,
methods: &[&'static str],
) where
T: Clone,
{
if svc_filter.contains(filter) {
if out.is_some() {
panic!("Overlapping method route. Cannot add two method routes that both handle `{}`", method_name)
}
*out = Some(svc.clone());
for method in methods {
append_allow_header(allow_header, method);
}
}
}
// written with a pattern match like this to ensure we update all fields
let Self {
mut get,
mut head,
mut delete,
mut options,
mut patch,
mut post,
mut put,
mut trace,
fallback,
mut allow_header,
_request_body: _,
} = self;
let svc = Some(Route::new(svc));
set_service!(
let svc = Route::new(svc);
set_service(
"GET",
&mut self.get,
&svc,
filter,
svc,
allow_header,
[
(get, GET, ["GET", "HEAD"]),
(head, HEAD, ["HEAD"]),
(delete, DELETE, ["DELETE"]),
(options, OPTIONS, ["OPTIONS"]),
(patch, PATCH, ["PATCH"]),
(post, POST, ["POST"]),
(put, PUT, ["PUT"]),
(trace, TRACE, ["TRACE"]),
]
MethodFilter::GET,
&mut self.allow_header,
&["GET", "HEAD"],
);
Self {
get,
head,
delete,
options,
patch,
post,
put,
trace,
fallback,
allow_header,
_request_body: PhantomData,
}
set_service(
"HEAD",
&mut self.head,
&svc,
filter,
MethodFilter::HEAD,
&mut self.allow_header,
&["HEAD"],
);
set_service(
"TRACE",
&mut self.trace,
&svc,
filter,
MethodFilter::TRACE,
&mut self.allow_header,
&["TRACE"],
);
set_service(
"PUT",
&mut self.put,
&svc,
filter,
MethodFilter::PUT,
&mut self.allow_header,
&["PUT"],
);
set_service(
"POST",
&mut self.post,
&svc,
filter,
MethodFilter::POST,
&mut self.allow_header,
&["POST"],
);
set_service(
"PATCH",
&mut self.patch,
&svc,
filter,
MethodFilter::PATCH,
&mut self.allow_header,
&["PATCH"],
);
set_service(
"OPTIONS",
&mut self.options,
&svc,
filter,
MethodFilter::OPTIONS,
&mut self.allow_header,
&["OPTIONS"],
);
set_service(
"DELETE",
&mut self.delete,
&svc,
filter,
MethodFilter::DELETE,
&mut self.allow_header,
&["DELETE"],
);
self
}
fn skip_allow_header(mut self) -> Self {
@ -998,7 +996,6 @@ impl<B, E> Clone for MethodRouter<B, E> {
trace: self.trace.clone(),
fallback: self.fallback.clone(),
allow_header: self.allow_header.clone(),
_request_body: PhantomData,
}
}
}
@ -1056,7 +1053,6 @@ where
trace,
fallback,
allow_header,
_request_body: _,
} = self;
call!(req, method, HEAD, head);
@ -1091,7 +1087,7 @@ mod tests {
use axum_core::response::IntoResponse;
use http::{header::ALLOW, HeaderMap};
use std::time::Duration;
use tower::{timeout::TimeoutLayer, Service, ServiceExt};
use tower::{timeout::TimeoutLayer, Service, ServiceBuilder, ServiceExt};
use tower_http::{auth::RequireAuthorizationLayer, services::fs::ServeDir};
#[tokio::test]

View file

@ -126,12 +126,16 @@ where
T::Response: IntoResponse,
T::Future: Send + 'static,
{
if path.is_empty() {
panic!("Paths must start with a `/`. Use \"/\" for root routes");
} else if !path.starts_with('/') {
panic!("Paths must start with a `/`");
fn validate_path(path: &str) {
if path.is_empty() {
panic!("Paths must start with a `/`. Use \"/\" for root routes");
} else if !path.starts_with('/') {
panic!("Paths must start with a `/`");
}
}
validate_path(path);
let service = match try_downcast::<Router<B>, _>(service) {
Ok(_) => {
panic!("Invalid route: `Router::route` cannot be used with `Router`s. Use `Router::nest` instead")
@ -162,16 +166,20 @@ where
Err(service) => Endpoint::Route(Route::new(service)),
};
self.set_node(path, id);
self.routes.insert(id, service);
self
}
fn set_node(&mut self, path: &str, id: RouteId) {
let mut node =
Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone());
if let Err(err) = node.insert(path, id) {
self.panic_on_matchit_error(err);
}
self.node = Arc::new(node);
self.routes.insert(id, service);
self
}
#[doc = include_str!("../docs/routing/nest.md")]
@ -419,28 +427,38 @@ where
let id = *match_.value;
#[cfg(feature = "matched-path")]
if let Some(matched_path) = self.node.route_id_to_path.get(&id) {
use crate::extract::MatchedPath;
{
fn set_matched_path(
id: RouteId,
route_id_to_path: &HashMap<RouteId, Arc<str>>,
extensions: &mut http::Extensions,
) {
if let Some(matched_path) = route_id_to_path.get(&id) {
use crate::extract::MatchedPath;
let matched_path = if let Some(previous) = req.extensions_mut().get::<MatchedPath>() {
// a previous `MatchedPath` might exist if we're inside a nested Router
let previous = if let Some(previous) =
previous.as_str().strip_suffix(NEST_TAIL_PARAM_CAPTURE)
{
previous
let matched_path = if let Some(previous) = extensions.get::<MatchedPath>() {
// a previous `MatchedPath` might exist if we're inside a nested Router
let previous = if let Some(previous) =
previous.as_str().strip_suffix(NEST_TAIL_PARAM_CAPTURE)
{
previous
} else {
previous.as_str()
};
let matched_path = format!("{}{}", previous, matched_path);
matched_path.into()
} else {
Arc::clone(matched_path)
};
extensions.insert(MatchedPath(matched_path));
} else {
previous.as_str()
};
#[cfg(debug_assertions)]
panic!("should always have a matched path for a route id");
}
}
let matched_path = format!("{}{}", previous, matched_path);
matched_path.into()
} else {
Arc::clone(matched_path)
};
req.extensions_mut().insert(MatchedPath(matched_path));
} else {
#[cfg(debug_assertions)]
panic!("should always have a matched path for a route id");
set_matched_path(id, &self.node.route_id_to_path, req.extensions_mut());
}
url_params::insert_url_params(req.extensions_mut(), match_.params);

View file

@ -6,7 +6,7 @@ use axum_core::response::IntoResponse;
use bytes::Bytes;
use http::{
header::{self, CONTENT_LENGTH},
HeaderValue, Request,
HeaderMap, HeaderValue, Request,
};
use pin_project_lite::pin_project;
use std::{
@ -161,10 +161,10 @@ where
res.extensions_mut().insert(AlreadyPassedThroughRouteFuture);
}
set_allow_header(&mut res, this.allow_header);
set_allow_header(res.headers_mut(), this.allow_header);
// make sure to set content-length before removing the body
set_content_length(&mut res);
set_content_length(res.size_hint(), res.headers_mut());
let res = if *this.strip_body {
res.map(|_| boxed(Empty::new()))
@ -176,10 +176,10 @@ where
}
}
fn set_allow_header<B>(res: &mut Response<B>, allow_header: &mut Option<Bytes>) {
fn set_allow_header(headers: &mut HeaderMap, allow_header: &mut Option<Bytes>) {
match allow_header.take() {
Some(allow_header) if !res.headers().contains_key(header::ALLOW) => {
res.headers_mut().insert(
Some(allow_header) if !headers.contains_key(header::ALLOW) => {
headers.insert(
header::ALLOW,
HeaderValue::from_maybe_shared(allow_header).expect("invalid `Allow` header"),
);
@ -188,15 +188,12 @@ fn set_allow_header<B>(res: &mut Response<B>, allow_header: &mut Option<Bytes>)
}
}
fn set_content_length<B>(res: &mut Response<B>)
where
B: HttpBody,
{
if res.headers().contains_key(CONTENT_LENGTH) {
fn set_content_length(size_hint: http_body::SizeHint, headers: &mut HeaderMap) {
if headers.contains_key(CONTENT_LENGTH) {
return;
}
if let Some(size) = res.size_hint().exact() {
if let Some(size) = size_hint.exact() {
let header_value = if size == 0 {
#[allow(clippy::declare_interior_mutable_const)]
const ZERO: HeaderValue = HeaderValue::from_static("0");
@ -207,7 +204,7 @@ where
HeaderValue::from_str(buffer.format(size)).unwrap()
};
res.headers_mut().insert(CONTENT_LENGTH, header_value);
headers.insert(CONTENT_LENGTH, header_value);
}
}