1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-04-26 13:56:22 +02:00

Change routing DSL

This commit is contained in:
David Pedersen 2021-06-04 01:00:48 +02:00
parent e156bc40e1
commit c3977d0b71
8 changed files with 736 additions and 1014 deletions

View file

@ -19,6 +19,7 @@ serde_urlencoded = "0.7"
thiserror = "1.0"
tower = { version = "0.4", features = ["util", "buffer"] }
tower-http = { version = "0.1", features = ["add-extension"] }
regex = "1.5"
[dev-dependencies]
hyper = { version = "0.14", features = ["full"] }

View file

@ -2,20 +2,14 @@ use http::{Request, StatusCode};
use hyper::Server;
use std::net::SocketAddr;
use tower::make::Shared;
use tower_web::{body::Body, extract, response::Html};
use tower_web::{body::Body, response, get, route, AddRoute, extract};
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
// build our application with some routes
let app = tower_web::app()
.at("/")
.get(handler)
.at("/greet/:name")
.get(greet)
// convert it into a `Service`
.into_service();
let app = route("/", get(handler)).route("/greet/:name", get(greet));
// run it with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
@ -24,8 +18,8 @@ async fn main() {
server.await.unwrap();
}
async fn handler(_req: Request<Body>) -> Html<&'static str> {
Html("<h1>Hello, World!</h1>")
async fn handler(_req: Request<Body>) -> response::Html<&'static str> {
response::Html("<h1>Hello, World!</h1>")
}
async fn greet(_req: Request<Body>, params: extract::UrlParamsMap) -> Result<String, StatusCode> {

View file

@ -14,7 +14,7 @@ use tower_http::{
use tower_web::{
body::Body,
extract::{BytesMaxLength, Extension, UrlParams},
handler::Handler,
get, route, Handler,
};
#[tokio::main]
@ -22,12 +22,10 @@ async fn main() {
tracing_subscriber::fmt::init();
// build our application with some routes
let app = tower_web::app()
.at("/:key")
.get(get.layer(CompressionLayer::new()))
.post(set)
// convert it into a `Service`
.into_service();
let app = route(
"/:key",
get(kv_get.layer(CompressionLayer::new())).post(kv_set),
);
// add some middleware
let app = ServiceBuilder::new()
@ -50,7 +48,7 @@ struct State {
db: HashMap<String, Bytes>,
}
async fn get(
async fn kv_get(
_req: Request<Body>,
UrlParams((key,)): UrlParams<(String,)>,
Extension(state): Extension<SharedState>,
@ -64,7 +62,7 @@ async fn get(
}
}
async fn set(
async fn kv_set(
_req: Request<Body>,
UrlParams((key,)): UrlParams<(String,)>,
BytesMaxLength(value): BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb

View file

@ -1,232 +0,0 @@
use http::Request;
use hyper::Server;
use std::net::SocketAddr;
use tower::make::Shared;
use tower_web::body::Body;
#[tokio::main]
async fn main() {
// 100 routes should still compile in a reasonable amount of time
// add a .boxed() every 10 routes to improve compile times
let app = tower_web::app()
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.boxed()
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.boxed()
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.boxed()
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.boxed()
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.boxed()
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.boxed()
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.boxed()
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.boxed()
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.boxed()
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.at("/")
.get(handler)
.boxed()
.into_service();
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
let server = Server::bind(&addr).serve(Shared::new(app));
server.await.unwrap();
}
async fn handler(_req: Request<Body>) -> &'static str {
"Hello, World!"
}

View file

@ -143,11 +143,11 @@ impl<S, T> Layered<S, T> {
}
}
pub fn handle_error<F, B, Res>(self, f: F) -> Layered<HandleError<S, F, S::Error>, T>
pub fn handle_error<F, B, Res>(self, f: F) -> Layered<HandleError<S, F>, T>
where
S: Service<Request<Body>, Response = Response<B>>,
F: FnOnce(S::Error) -> Res,
Res: IntoResponse<Body>,
Res: IntoResponse<B>,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
{

View file

@ -1,13 +1,12 @@
use self::{
body::Body,
routing::{AlwaysNotFound, RouteAt},
};
use self::body::Body;
use body::BoxBody;
use bytes::Bytes;
use futures_util::ready;
use http::{Request, Response};
use handler::HandlerSvc;
use http::{Method, Request, Response};
use pin_project::pin_project;
use response::IntoResponse;
use routing::{EmptyRouter, OnMethod, Route};
use std::{
convert::Infallible,
fmt,
@ -15,7 +14,7 @@ use std::{
pin::Pin,
task::{Context, Poll},
};
use tower::{BoxError, Service};
use tower::{util::Oneshot, BoxError, Service, ServiceExt as _};
pub mod body;
pub mod extract;
@ -23,66 +22,81 @@ pub mod handler;
pub mod response;
pub mod routing;
#[doc(inline)]
pub use self::handler::Handler;
#[doc(inline)]
pub use self::routing::AddRoute;
pub use async_trait::async_trait;
pub use tower_http::add_extension::{AddExtension, AddExtensionLayer};
#[derive(Debug, Copy, Clone)]
pub enum MethodFilter {
Any,
Connect,
Delete,
Get,
Head,
Options,
Patch,
Post,
Put,
Trace,
}
impl MethodFilter {
#[allow(clippy::match_like_matches_macro)]
fn matches(self, method: &Method) -> bool {
use MethodFilter::*;
match (self, method) {
(Any, _)
| (Connect, &Method::CONNECT)
| (Delete, &Method::DELETE)
| (Get, &Method::GET)
| (Head, &Method::HEAD)
| (Options, &Method::OPTIONS)
| (Patch, &Method::PATCH)
| (Post, &Method::POST)
| (Put, &Method::PUT)
| (Trace, &Method::TRACE) => true,
_ => false,
}
}
}
pub fn route<S>(spec: &str, svc: S) -> Route<S, EmptyRouter>
where
S: Service<Request<Body>, Error = Infallible> + Clone,
{
routing::EmptyRouter.route(spec, svc)
}
pub fn get<H, B, T>(handler: H) -> OnMethod<HandlerSvc<H, B, T>, EmptyRouter>
where
H: Handler<B, T>,
{
on_method(MethodFilter::Get, HandlerSvc::new(handler))
}
pub fn post<H, B, T>(handler: H) -> OnMethod<HandlerSvc<H, B, T>, EmptyRouter>
where
H: Handler<B, T>,
{
on_method(MethodFilter::Post, HandlerSvc::new(handler))
}
pub fn on_method<S>(method: MethodFilter, svc: S) -> OnMethod<S, EmptyRouter> {
OnMethod {
method,
svc,
fallback: EmptyRouter,
}
}
#[cfg(test)]
mod tests;
pub fn app() -> App<AlwaysNotFound> {
App {
service_tree: AlwaysNotFound(()),
}
}
#[derive(Debug, Clone)]
pub struct App<R> {
service_tree: R,
}
impl<R> App<R> {
fn new(service_tree: R) -> Self {
Self { service_tree }
}
pub fn at(self, route_spec: &str) -> RouteAt<R> {
self.at_bytes(Bytes::copy_from_slice(route_spec.as_bytes()))
}
fn at_bytes(self, route_spec: Bytes) -> RouteAt<R> {
RouteAt {
app: self,
route_spec,
}
}
}
#[derive(Clone)]
pub struct IntoService<R> {
service_tree: R
}
impl<R, B, T> Service<T> for IntoService<R>
where
R: Service<T, Response = Response<B>, Error = Infallible>,
B: Default,
{
type Response = Response<B>;
type Error = Infallible;
type Future = R::Future;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match ready!(self.service_tree.poll_ready(cx)) {
Ok(_) => Poll::Ready(Ok(())),
Err(err) => match err {},
}
}
fn call(&mut self, req: T) -> Self::Future {
self.service_tree.call(req)
}
}
pub(crate) trait ResultExt<T> {
fn unwrap_infallible(self) -> T;
}
@ -105,11 +119,11 @@ impl<T> ResultExt<T> for Result<T, Infallible> {
// Fixing that is a breaking change to tower-http so we should wait a bit, but should
// totally fix it at some point.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
pub struct BoxStdError(#[source] pub(crate) tower::BoxError);
#[error(transparent)]
pub struct BoxStdError(#[from] pub(crate) tower::BoxError);
pub trait ServiceExt<B>: Service<Request<Body>, Response = Response<B>> {
fn handle_error<F, Res>(self, f: F) -> HandleError<Self, F, Self::Error>
fn handle_error<F, Res>(self, f: F) -> HandleError<Self, F>
where
Self: Sized,
F: FnOnce(Self::Error) -> Res,
@ -123,53 +137,33 @@ pub trait ServiceExt<B>: Service<Request<Body>, Response = Response<B>> {
impl<S, B> ServiceExt<B> for S where S: Service<Request<Body>, Response = Response<B>> {}
pub struct HandleError<S, F, E> {
#[derive(Clone)]
pub struct HandleError<S, F> {
inner: S,
f: F,
poll_ready_error: Option<E>,
}
impl<S, F, E> HandleError<S, F, E> {
impl<S, F> HandleError<S, F> {
pub(crate) fn new(inner: S, f: F) -> Self {
Self {
inner,
f,
poll_ready_error: None,
}
Self { inner, f }
}
}
impl<S, F, E> fmt::Debug for HandleError<S, F, E>
impl<S, F> fmt::Debug for HandleError<S, F>
where
S: fmt::Debug,
E: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("HandleError")
.field("inner", &self.inner)
.field("f", &format_args!("{}", std::any::type_name::<F>()))
.field("poll_ready_error", &self.poll_ready_error)
.finish()
}
}
impl<S, F, E> Clone for HandleError<S, F, E>
impl<S, F, B, Res> Service<Request<Body>> for HandleError<S, F>
where
S: Clone,
F: Clone,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
f: self.f.clone(),
poll_ready_error: None,
}
}
}
impl<S, F, B, Res> Service<Request<Body>> for HandleError<S, F, S::Error>
where
S: Service<Request<Body>, Response = Response<B>>,
S: Service<Request<Body>, Response = Response<B>> + Clone,
F: FnOnce(S::Error) -> Res + Clone,
Res: IntoResponse<Body>,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
@ -177,47 +171,28 @@ where
{
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = HandleErrorFuture<S::Future, F, S::Error>;
type Future = HandleErrorFuture<Oneshot<S, Request<Body>>, F>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match ready!(self.inner.poll_ready(cx)) {
Ok(_) => Poll::Ready(Ok(())),
Err(err) => {
self.poll_ready_error = Some(err);
Poll::Ready(Ok(()))
}
}
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
if let Some(err) = self.poll_ready_error.take() {
return HandleErrorFuture {
f: Some(self.f.clone()),
kind: Kind::Error(Some(err)),
};
}
HandleErrorFuture {
f: Some(self.f.clone()),
kind: Kind::Future(self.inner.call(req)),
inner: self.inner.clone().oneshot(req),
}
}
}
#[pin_project]
pub struct HandleErrorFuture<Fut, F, E> {
pub struct HandleErrorFuture<Fut, F> {
#[pin]
kind: Kind<Fut, E>,
inner: Fut,
f: Option<F>,
}
#[pin_project(project = KindProj)]
enum Kind<Fut, E> {
Future(#[pin] Fut),
Error(Option<E>),
}
impl<Fut, F, E, B, Res> Future for HandleErrorFuture<Fut, F, E>
impl<Fut, F, E, B, Res> Future for HandleErrorFuture<Fut, F>
where
Fut: Future<Output = Result<Response<B>, E>>,
F: FnOnce(E) -> Res,
@ -230,18 +205,11 @@ where
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match this.kind.project() {
KindProj::Future(future) => match ready!(future.poll(cx)) {
Ok(res) => Ok(res.map(BoxBody::new)).into(),
Err(err) => {
let f = this.f.take().unwrap();
let res = f(err).into_response();
Ok(res.map(BoxBody::new)).into()
}
},
KindProj::Error(err) => {
match ready!(this.inner.poll(cx)) {
Ok(res) => Ok(res.map(BoxBody::new)).into(),
Err(err) => {
let f = this.f.take().unwrap();
let res = f(err.take().unwrap()).into_response();
let res = f(err).into_response();
Ok(res.map(BoxBody::new)).into()
}
}

View file

@ -1,369 +1,145 @@
use crate::{
body::{Body, BoxBody},
body::BoxBody,
handler::{Handler, HandlerSvc},
response::IntoResponse,
App, HandleError, IntoService, ResultExt,
MethodFilter, ResultExt,
};
use bytes::Bytes;
use futures_util::{future, ready};
use http::{Method, Request, Response, StatusCode};
use itertools::{EitherOrBoth, Itertools};
use http::{Request, Response, StatusCode};
use hyper::Body;
use itertools::Itertools;
use pin_project::pin_project;
use regex::Regex;
use std::{
borrow::Cow,
convert::Infallible,
fmt,
future::Future,
pin::Pin,
str,
sync::Arc,
task::{Context, Poll},
};
use tower::{
buffer::{Buffer, BufferLayer},
util::BoxService,
buffer::Buffer,
util::{BoxService, Oneshot, ServiceExt},
BoxError, Layer, Service, ServiceBuilder,
};
#[derive(Clone, Copy)]
pub struct AlwaysNotFound(pub(crate) ());
// ===== DSL =====
impl<R> Service<R> for AlwaysNotFound {
type Response = Response<Body>;
type Error = Infallible;
type Future = future::Ready<Result<Self::Response, Self::Error>>;
#[derive(Clone)]
pub struct Route<S, F> {
pub(crate) pattern: PathPattern,
pub(crate) svc: S,
pub(crate) fallback: F,
}
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
#[derive(Clone)]
pub struct OnMethod<S, F> {
pub(crate) method: MethodFilter,
pub(crate) svc: S,
pub(crate) fallback: F,
}
pub trait AddRoute: Sized {
fn route<T>(self, spec: &str, svc: T) -> Route<T, Self>
where
T: Service<Request<Body>, Error = Infallible> + Clone;
}
impl<S, F> Route<S, F> {
pub fn boxed<B>(self) -> BoxRoute<B>
where
Self: Service<Request<Body>, Response = Response<B>, Error = Infallible> + Send + 'static,
<Self as Service<Request<Body>>>::Future: Send,
B: From<String> + 'static,
{
ServiceBuilder::new()
.layer_fn(BoxRoute)
.buffer(1024)
.layer(BoxService::layer())
.service(self)
}
fn call(&mut self, _req: R) -> Self::Future {
let mut res = Response::new(Body::empty());
*res.status_mut() = StatusCode::NOT_FOUND;
future::ok(res)
pub fn layer<L>(self, layer: L) -> Layered<L::Service>
where
L: Layer<Self>,
L::Service: Service<Request<Body>> + Clone,
{
Layered(layer.layer(self))
}
}
#[derive(Debug, Clone)]
pub struct RouteAt<R> {
pub(crate) app: App<R>,
pub(crate) route_spec: Bytes,
impl<S, F> AddRoute for Route<S, F> {
fn route<T>(self, spec: &str, svc: T) -> Route<T, Self>
where
T: Service<Request<Body>, Error = Infallible> + Clone,
{
Route {
pattern: PathPattern::new(spec),
svc,
fallback: self,
}
}
}
macro_rules! define_route_at_methods {
(
RouteAt:
$name:ident,
$svc_method_name:ident,
$method:ident
) => {
pub fn $name<F, B, T>(self, handler_fn: F) -> RouteBuilder<Or<HandlerSvc<F, B, T>, R>>
where
F: Handler<B, T>,
{
self.add_route(handler_fn, Method::$method)
}
pub fn $svc_method_name<S, B>(self, service: S) -> RouteBuilder<Or<S, R>>
where
S: Service<Request<Body>, Response = Response<B>, Error = Infallible> + Clone,
{
self.add_route_service(service, Method::$method)
}
};
(
RouteBuilder:
$name:ident,
$svc_method_name:ident,
$method:ident
) => {
pub fn $name<F, B, T>(self, handler_fn: F) -> RouteBuilder<Or<HandlerSvc<F, B, T>, R>>
where
F: Handler<B, T>,
{
self.app.at_bytes(self.route_spec).$name(handler_fn)
}
pub fn $svc_method_name<S, B>(self, service: S) -> RouteBuilder<Or<S, R>>
where
S: Service<Request<Body>, Response = Response<B>, Error = Infallible> + Clone,
{
self.app.at_bytes(self.route_spec).$svc_method_name(service)
}
};
}
impl<R> RouteAt<R> {
define_route_at_methods!(RouteAt: get, get_service, GET);
define_route_at_methods!(RouteAt: post, post_service, POST);
define_route_at_methods!(RouteAt: put, put_service, PUT);
define_route_at_methods!(RouteAt: patch, patch_service, PATCH);
define_route_at_methods!(RouteAt: delete, delete_service, DELETE);
define_route_at_methods!(RouteAt: head, head_service, HEAD);
define_route_at_methods!(RouteAt: options, options_service, OPTIONS);
define_route_at_methods!(RouteAt: connect, connect_service, CONNECT);
define_route_at_methods!(RouteAt: trace, trace_service, TRACE);
fn add_route<H, B, T>(
self,
handler: H,
method: Method,
) -> RouteBuilder<Or<HandlerSvc<H, B, T>, R>>
impl<S, F> OnMethod<S, F> {
pub fn get<H, B, T>(self, handler: H) -> OnMethod<HandlerSvc<H, B, T>, Self>
where
H: Handler<B, T>,
{
self.add_route_service(HandlerSvc::new(handler), method)
self.with_method(MethodFilter::Get, HandlerSvc::new(handler))
}
fn add_route_service<S>(self, service: S, method: Method) -> RouteBuilder<Or<S, R>> {
let route_spec = self.route_spec.clone();
self.add_route_service_with_spec(service, RouteSpec::new(method, route_spec))
}
fn add_route_service_with_spec<S>(
self,
service: S,
route_spec: RouteSpec,
) -> RouteBuilder<Or<S, R>> {
assert!(
self.route_spec.starts_with(b"/"),
"route spec must start with a slash (`/`)"
);
let new_app = App {
service_tree: Or {
service,
route_spec,
fallback: self.app.service_tree,
handler_ready: false,
fallback_ready: false,
},
};
RouteBuilder {
app: new_app,
route_spec: self.route_spec,
}
}
}
pub struct RouteBuilder<R> {
app: App<R>,
route_spec: Bytes,
}
impl<R> Clone for RouteBuilder<R>
where
R: Clone,
{
fn clone(&self) -> Self {
Self {
app: self.app.clone(),
route_spec: self.route_spec.clone(),
}
}
}
impl<R> RouteBuilder<R> {
fn new(app: App<R>, route_spec: impl Into<Bytes>) -> Self {
Self {
app,
route_spec: route_spec.into(),
}
}
pub fn at(self, route_spec: &str) -> RouteAt<R> {
self.app.at(route_spec)
}
define_route_at_methods!(RouteBuilder: get, get_service, GET);
define_route_at_methods!(RouteBuilder: post, post_service, POST);
define_route_at_methods!(RouteBuilder: put, put_service, PUT);
define_route_at_methods!(RouteBuilder: patch, patch_service, PATCH);
define_route_at_methods!(RouteBuilder: delete, delete_service, DELETE);
define_route_at_methods!(RouteBuilder: head, head_service, HEAD);
define_route_at_methods!(RouteBuilder: options, options_service, OPTIONS);
define_route_at_methods!(RouteBuilder: connect, connect_service, CONNECT);
define_route_at_methods!(RouteBuilder: trace, trace_service, TRACE);
pub fn into_service(self) -> IntoService<R> {
IntoService {
service_tree: self.app.service_tree,
}
}
pub fn layer<L>(self, layer: L) -> RouteBuilder<L::Service>
pub fn post<H, B, T>(self, handler: H) -> OnMethod<HandlerSvc<H, B, T>, Self>
where
L: Layer<R>,
H: Handler<B, T>,
{
let layered = layer.layer(self.app.service_tree);
let app = App::new(layered);
RouteBuilder::new(app, self.route_spec)
self.with_method(MethodFilter::Post, HandlerSvc::new(handler))
}
pub fn handle_error<F, B, Res>(self, f: F) -> RouteBuilder<HandleError<R, F, R::Error>>
where
R: Service<Request<Body>, Response = Response<B>>,
F: FnOnce(R::Error) -> Res,
Res: IntoResponse<Body>,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
{
let svc = HandleError::new(self.app.service_tree, f);
let app = App::new(svc);
RouteBuilder::new(app, self.route_spec)
}
pub fn boxed<B>(self) -> RouteBuilder<BoxServiceTree<B>>
where
R: Service<Request<Body>, Response = Response<B>, Error = Infallible> + Send + 'static,
R::Future: Send,
B: From<String> + 'static,
{
let svc = ServiceBuilder::new()
.layer(BufferLayer::new(1024))
.layer(BoxService::layer())
.service(self.app.service_tree);
let app = App::new(BoxServiceTree {
inner: svc,
poll_ready_error: None,
});
RouteBuilder::new(app, self.route_spec)
}
}
pub struct Or<H, F> {
service: H,
route_spec: RouteSpec,
fallback: F,
handler_ready: bool,
fallback_ready: bool,
}
impl<H, F> Clone for Or<H, F>
where
H: Clone,
F: Clone,
{
fn clone(&self) -> Self {
Self {
service: self.service.clone(),
fallback: self.fallback.clone(),
route_spec: self.route_spec.clone(),
// important to reset readiness when cloning
handler_ready: false,
fallback_ready: false,
}
}
}
#[derive(Debug, Clone)]
struct RouteSpec {
method: Method,
spec: Bytes,
}
impl RouteSpec {
fn new(method: Method, spec: impl Into<Bytes>) -> Self {
Self {
pub fn with_method<T>(self, method: MethodFilter, svc: T) -> OnMethod<T, Self> {
OnMethod {
method,
spec: spec.into(),
svc,
fallback: self,
}
}
}
impl RouteSpec {
fn matches<B>(&self, req: &Request<B>) -> Option<Vec<(String, String)>> {
// TODO(david): perform this matching outside
if req.method() != self.method {
return None;
}
// ===== Routing service impls =====
let spec_parts = self.spec.split(|b| *b == b'/');
let path = req.uri().path().as_bytes();
let path_parts = path.split(|b| *b == b'/');
let mut params = Vec::new();
for pair in spec_parts.zip_longest(path_parts) {
match pair {
EitherOrBoth::Both(spec, path) => {
if let Some(key) = spec.strip_prefix(b":") {
let key = str::from_utf8(key).unwrap().to_string();
if let Ok(value) = std::str::from_utf8(path) {
params.push((key, value.to_string()));
} else {
return None;
}
} else if spec != path {
return None;
}
}
EitherOrBoth::Left(_) | EitherOrBoth::Right(_) => {
return None;
}
}
}
Some(params)
}
}
impl<H, F, HB, FB> Service<Request<Body>> for Or<H, F>
impl<S, F, SB, FB> Service<Request<Body>> for Route<S, F>
where
H: Service<Request<Body>, Response = Response<HB>, Error = Infallible>,
HB: http_body::Body<Data = Bytes> + Send + Sync + 'static,
HB::Error: Into<BoxError>,
S: Service<Request<Body>, Response = Response<SB>, Error = Infallible> + Clone,
SB: http_body::Body<Data = Bytes> + Send + Sync + 'static,
SB::Error: Into<BoxError>,
F: Service<Request<Body>, Response = Response<FB>, Error = Infallible>,
F: Service<Request<Body>, Response = Response<FB>, Error = Infallible> + Clone,
FB: http_body::Body<Data = Bytes> + Send + Sync + 'static,
FB::Error: Into<BoxError>,
{
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = future::Either<BoxResponseBody<H::Future>, BoxResponseBody<F::Future>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
loop {
if !self.handler_ready {
ready!(self.service.poll_ready(cx)).unwrap_infallible();
self.handler_ready = true;
}
#[allow(clippy::type_complexity)]
type Future = future::Either<
BoxResponseBody<Oneshot<S, Request<Body>>>,
BoxResponseBody<Oneshot<F, Request<Body>>>,
>;
if !self.fallback_ready {
ready!(self.fallback.poll_ready(cx)).unwrap_infallible();
self.fallback_ready = true;
}
if self.handler_ready && self.fallback_ready {
return Poll::Ready(Ok(()));
}
}
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
if let Some(params) = self.route_spec.matches(&req) {
assert!(
self.handler_ready,
"handler not ready. Did you forget to call `poll_ready`?"
);
self.handler_ready = false;
insert_url_params(&mut req, params);
future::Either::Left(BoxResponseBody(self.service.call(req)))
if let Some(captures) = self.pattern.matches(req.uri().path()) {
insert_url_params(&mut req, captures);
let response_future = self.svc.clone().oneshot(req);
future::Either::Left(BoxResponseBody(response_future))
} else {
assert!(
self.fallback_ready,
"fallback not ready. Did you forget to call `poll_ready`?"
);
self.fallback_ready = false;
// TODO(david): this leads to each route creating one box body, probably not great
future::Either::Right(BoxResponseBody(self.fallback.call(req)))
let response_future = self.fallback.clone().oneshot(req);
future::Either::Right(BoxResponseBody(response_future))
}
}
}
@ -371,6 +147,50 @@ where
#[derive(Debug)]
pub(crate) struct UrlParams(pub(crate) Vec<(String, String)>);
fn insert_url_params<B>(req: &mut Request<B>, params: Vec<(String, String)>) {
if let Some(current) = req.extensions_mut().get_mut::<Option<UrlParams>>() {
let mut current = current.take().unwrap();
current.0.extend(params);
req.extensions_mut().insert(Some(current));
} else {
req.extensions_mut().insert(Some(UrlParams(params)));
}
}
impl<S, F, SB, FB> Service<Request<Body>> for OnMethod<S, F>
where
S: Service<Request<Body>, Response = Response<SB>, Error = Infallible> + Clone,
SB: http_body::Body<Data = Bytes> + Send + Sync + 'static,
SB::Error: Into<BoxError>,
F: Service<Request<Body>, Response = Response<FB>, Error = Infallible> + Clone,
FB: http_body::Body<Data = Bytes> + Send + Sync + 'static,
FB::Error: Into<BoxError>,
{
type Response = Response<BoxBody>;
type Error = Infallible;
#[allow(clippy::type_complexity)]
type Future = future::Either<
BoxResponseBody<Oneshot<S, Request<Body>>>,
BoxResponseBody<Oneshot<F, Request<Body>>>,
>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
if self.method.matches(req.method()) {
let response_future = self.svc.clone().oneshot(req);
future::Either::Left(BoxResponseBody(response_future))
} else {
let response_future = self.fallback.clone().oneshot(req);
future::Either::Right(BoxResponseBody(response_future))
}
}
}
#[pin_project]
pub struct BoxResponseBody<F>(#[pin] F);
@ -392,91 +212,156 @@ where
}
}
pub struct BoxServiceTree<B> {
inner: Buffer<BoxService<Request<Body>, Response<B>, Infallible>, Request<Body>>,
poll_ready_error: Option<BoxError>,
}
#[derive(Clone, Copy)]
pub struct EmptyRouter;
impl<B> Clone for BoxServiceTree<B> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
poll_ready_error: None,
impl AddRoute for EmptyRouter {
fn route<S>(self, spec: &str, svc: S) -> Route<S, Self>
where
S: Service<Request<Body>, Error = Infallible> + Clone,
{
Route {
pattern: PathPattern::new(spec),
svc,
fallback: self,
}
}
}
impl<B> fmt::Debug for BoxServiceTree<B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BoxServiceTree").finish()
impl<R> Service<R> for EmptyRouter {
type Response = Response<Body>;
type Error = Infallible;
type Future = future::Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: R) -> Self::Future {
let mut res = Response::new(Body::empty());
*res.status_mut() = StatusCode::NOT_FOUND;
future::ok(res)
}
}
impl<B> Service<Request<Body>> for BoxServiceTree<B>
// ===== PathPattern =====
#[derive(Debug, Clone)]
pub(crate) struct PathPattern(Arc<Inner>);
#[derive(Debug)]
struct Inner {
full_path_regex: Regex,
capture_group_names: Box<[Bytes]>,
}
impl PathPattern {
pub(crate) fn new(pattern: &str) -> Self {
let mut capture_group_names = Vec::new();
let pattern = pattern
.split('/')
.map(|part| {
if let Some(key) = part.strip_prefix(':') {
capture_group_names.push(Bytes::copy_from_slice(key.as_bytes()));
Cow::Owned(format!("(?P<{}>[^/]*)", key))
} else {
Cow::Borrowed(part)
}
})
.join("/");
let full_path_regex =
Regex::new(&format!("^{}$", pattern)).expect("invalid regex generated from route");
Self(Arc::new(Inner {
full_path_regex,
capture_group_names: capture_group_names.into(),
}))
}
pub(crate) fn matches(&self, path: &str) -> Option<Captures> {
self.0.full_path_regex.captures(path).map(|captures| {
let captures = self
.0
.capture_group_names
.iter()
.map(|bytes| {
std::str::from_utf8(bytes)
.expect("bytes were created from str so is valid utf-8")
})
.filter_map(|name| captures.name(name).map(|value| (name, value.as_str())))
.map(|(key, value)| (key.to_string(), value.to_string()))
.collect::<Vec<_>>();
captures
})
}
}
type Captures = Vec<(String, String)>;
// ===== BoxRoute =====
pub struct BoxRoute<B>(Buffer<BoxService<Request<Body>, Response<B>, Infallible>, Request<Body>>);
impl<B> Clone for BoxRoute<B> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<B> AddRoute for BoxRoute<B> {
fn route<S>(self, spec: &str, svc: S) -> Route<S, Self>
where
S: Service<Request<Body>, Error = Infallible> + Clone,
{
Route {
pattern: PathPattern::new(spec),
svc,
fallback: self,
}
}
}
impl<B> Service<Request<Body>> for BoxRoute<B>
where
B: From<String> + 'static,
{
type Response = Response<B>;
type Error = Infallible;
type Future = BoxServiceTreeResponseFuture<B>;
type Future = BoxRouteResponseFuture<B>;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// TODO(david): downcast this into one of the cases in `tower::buffer::error`
// and convert the error into a response. `ServiceError` should never be able to happen
// since all inner services use `Infallible` as the error type.
match ready!(self.inner.poll_ready(cx)) {
Ok(_) => Poll::Ready(Ok(())),
Err(err) => {
self.poll_ready_error = Some(err);
Poll::Ready(Ok(()))
}
}
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
#[inline]
fn call(&mut self, req: Request<Body>) -> Self::Future {
if let Some(err) = self.poll_ready_error.take() {
return BoxServiceTreeResponseFuture {
kind: Kind::Response(Some(handle_buffer_error(err))),
};
}
BoxServiceTreeResponseFuture {
kind: Kind::Future(self.inner.call(req)),
}
BoxRouteResponseFuture(self.0.clone().oneshot(req))
}
}
#[pin_project]
pub struct BoxServiceTreeResponseFuture<B> {
#[pin]
kind: Kind<B>,
}
pub struct BoxRouteResponseFuture<B>(#[pin] InnerFuture<B>);
#[pin_project(project = KindProj)]
enum Kind<B> {
Response(Option<Response<B>>),
Future(#[pin] InnerFuture<B>),
}
type InnerFuture<B> = tower::buffer::future::ResponseFuture<
Pin<Box<dyn Future<Output = Result<Response<B>, Infallible>> + Send + 'static>>,
type InnerFuture<B> = Oneshot<
Buffer<BoxService<Request<Body>, Response<B>, Infallible>, Request<Body>>,
Request<Body>,
>;
impl<B> Future for BoxServiceTreeResponseFuture<B>
impl<B> Future for BoxRouteResponseFuture<B>
where
B: From<String>,
{
type Output = Result<Response<B>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().kind.project() {
KindProj::Response(res) => Poll::Ready(Ok(res.take().unwrap())),
KindProj::Future(future) => match ready!(future.poll(cx)) {
Ok(res) => Poll::Ready(Ok(res)),
Err(err) => Poll::Ready(Ok(handle_buffer_error(err))),
},
match ready!(self.project().0.poll(cx)) {
Ok(res) => Poll::Ready(Ok(res)),
Err(err) => Poll::Ready(Ok(handle_buffer_error(err))),
}
}
}
@ -516,85 +401,170 @@ where
.unwrap()
}
fn insert_url_params<B>(req: &mut Request<B>, params: Vec<(String, String)>) {
if let Some(current) = req.extensions_mut().get_mut::<Option<UrlParams>>() {
let mut current = current.take().unwrap();
current.0.extend(params);
req.extensions_mut().insert(Some(current));
} else {
req.extensions_mut().insert(Some(UrlParams(params)));
// ===== Layered =====
#[derive(Clone, Debug)]
pub struct Layered<S>(S);
impl<S> AddRoute for Layered<S> {
fn route<T>(self, spec: &str, svc: T) -> Route<T, Self>
where
T: Service<Request<Body>, Error = Infallible> + Clone,
{
Route {
pattern: PathPattern::new(spec),
svc,
fallback: self,
}
}
}
impl<S> Layered<S> {
pub fn handle_error<F, B, Res>(self, f: F) -> HandleError<Self, F>
where
S: Service<Request<Body>, Response = Response<B>> + Clone,
F: FnOnce(S::Error) -> Res,
Res: IntoResponse<Body>,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
{
HandleError { inner: self, f }
}
}
impl<S, B> Service<Request<Body>> for Layered<S>
where
S: Service<Request<Body>, Response = Response<B>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx)
}
#[inline]
fn call(&mut self, req: Request<Body>) -> Self::Future {
self.0.call(req)
}
}
#[derive(Clone, Copy)]
pub struct HandleError<S, F> {
inner: S,
f: F,
}
impl<S, F> AddRoute for HandleError<S, F> {
fn route<T>(self, spec: &str, svc: T) -> Route<T, Self>
where
T: Service<Request<Body>, Error = Infallible> + Clone,
{
Route {
pattern: PathPattern::new(spec),
svc,
fallback: self,
}
}
}
impl<S, F, B, Res> Service<Request<Body>> for HandleError<S, F>
where
S: Service<Request<Body>, Response = Response<B>> + Clone,
F: FnOnce(S::Error) -> Res + Clone,
Res: IntoResponse<Body>,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = HandleErrorFuture<Oneshot<S, Request<Body>>, F>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
HandleErrorFuture {
inner: self.inner.clone().oneshot(req),
f: Some(self.f.clone()),
}
}
}
#[pin_project]
pub struct HandleErrorFuture<Fut, F> {
#[pin]
inner: Fut,
f: Option<F>,
}
impl<Fut, F, B, E, Res> Future for HandleErrorFuture<Fut, F>
where
Fut: Future<Output = Result<Response<B>, E>>,
F: FnOnce(E) -> Res,
Res: IntoResponse<Body>,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
{
type Output = Result<Response<BoxBody>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match ready!(this.inner.poll(cx)) {
Ok(res) => Ok(res.map(BoxBody::new)).into(),
Err(err) => {
let f = this.f.take().unwrap();
let res = f(err).into_response();
Ok(res.map(BoxBody::new)).into()
}
}
}
}
#[cfg(test)]
mod tests {
#[allow(unused_imports)]
use super::*;
#[test]
fn test_routing() {
assert_match((Method::GET, "/"), (Method::GET, "/"));
refute_match((Method::GET, "/"), (Method::POST, "/"));
refute_match((Method::POST, "/"), (Method::GET, "/"));
assert_match("/", "/");
assert_match((Method::GET, "/foo"), (Method::GET, "/foo"));
assert_match((Method::GET, "/foo/"), (Method::GET, "/foo/"));
refute_match((Method::GET, "/foo"), (Method::GET, "/foo/"));
refute_match((Method::GET, "/foo/"), (Method::GET, "/foo"));
assert_match("/foo", "/foo");
assert_match("/foo/", "/foo/");
refute_match("/foo", "/foo/");
refute_match("/foo/", "/foo");
assert_match((Method::GET, "/foo/bar"), (Method::GET, "/foo/bar"));
refute_match((Method::GET, "/foo/bar/"), (Method::GET, "/foo/bar"));
refute_match((Method::GET, "/foo/bar"), (Method::GET, "/foo/bar/"));
assert_match("/foo/bar", "/foo/bar");
refute_match("/foo/bar/", "/foo/bar");
refute_match("/foo/bar", "/foo/bar/");
assert_match((Method::GET, "/:value"), (Method::GET, "/foo"));
assert_match((Method::GET, "/users/:id"), (Method::GET, "/users/1"));
assert_match(
(Method::GET, "/users/:id/action"),
(Method::GET, "/users/42/action"),
);
refute_match(
(Method::GET, "/users/:id/action"),
(Method::GET, "/users/42"),
);
refute_match(
(Method::GET, "/users/:id"),
(Method::GET, "/users/42/action"),
assert_match("/:value", "/foo");
assert_match("/users/:id", "/users/1");
assert_match("/users/:id/action", "/users/42/action");
refute_match("/users/:id/action", "/users/42");
refute_match("/users/:id", "/users/42/action");
}
fn assert_match(route_spec: &'static str, path: &'static str) {
let route = PathPattern::new(route_spec);
assert!(
route.matches(&path).is_some(),
"`{}` doesn't match `{}`",
path,
route_spec
);
}
fn assert_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) {
let route = RouteSpec::new(route_spec.0.clone(), route_spec.1);
let req = Request::builder()
.method(req_spec.0.clone())
.uri(req_spec.1)
.body(())
.unwrap();
fn refute_match(route_spec: &'static str, path: &'static str) {
let route = PathPattern::new(route_spec);
assert!(
route.matches(&req).is_some(),
"`{} {}` doesn't match `{:?} {}`",
req.method(),
req.uri().path(),
route.method,
str::from_utf8(&route.spec).unwrap(),
);
}
fn refute_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) {
let route = RouteSpec::new(route_spec.0.clone(), route_spec.1);
let req = Request::builder()
.method(req_spec.0.clone())
.uri(req_spec.1)
.body(())
.unwrap();
assert!(
route.matches(&req).is_none(),
"`{} {}` shouldn't match `{:?} {}`",
req.method(),
req.uri().path(),
route.method,
str::from_utf8(&route.spec).unwrap(),
route.matches(&path).is_none(),
"`{}` did match `{}` (but shouldn't)",
path,
route_spec
);
}
}

View file

@ -1,4 +1,4 @@
use crate::{app, extract, handler::Handler};
use crate::{extract, get, on_method, post, route, AddRoute, Handler, MethodFilter};
use http::{Request, Response, StatusCode};
use hyper::{Body, Server};
use serde::Deserialize;
@ -7,29 +7,53 @@ use std::{
net::{SocketAddr, TcpListener},
time::Duration,
};
use tower::{make::Shared, BoxError, Service};
use tower::{make::Shared, BoxError, Service, ServiceBuilder};
use tower_http::compression::CompressionLayer;
#[tokio::test]
async fn hello_world() {
let app = app()
.at("/")
.get(|_: Request<Body>| async { "Hello, World!" })
.into_service();
async fn root(_: Request<Body>) -> &'static str {
"Hello, World!"
}
async fn foo(_: Request<Body>) -> &'static str {
"foo"
}
async fn users_create(_: Request<Body>) -> &'static str {
"users#create"
}
let app = route("/", get(root).post(foo)).route("/users", post(users_create));
let addr = run_in_background(app).await;
let res = reqwest::get(format!("http://{}", addr)).await.unwrap();
let body = res.text().await.unwrap();
let client = reqwest::Client::new();
let res = client.get(format!("http://{}", addr)).send().await.unwrap();
let body = res.text().await.unwrap();
assert_eq!(body, "Hello, World!");
let res = client
.post(format!("http://{}", addr))
.send()
.await
.unwrap();
let body = res.text().await.unwrap();
assert_eq!(body, "foo");
let res = client
.post(format!("http://{}/users", addr))
.send()
.await
.unwrap();
let body = res.text().await.unwrap();
assert_eq!(body, "users#create");
}
#[tokio::test]
async fn consume_body() {
let app = app()
.at("/")
.get(|_: Request<Body>, body: String| async { body })
.into_service();
let app = route("/", get(|_: Request<Body>, body: String| async { body }));
let addr = run_in_background(app).await;
@ -52,10 +76,10 @@ async fn deserialize_body() {
foo: String,
}
let app = app()
.at("/")
.post(|_: Request<Body>, input: extract::Json<Input>| async { input.0.foo })
.into_service();
let app = route(
"/",
post(|_: Request<Body>, input: extract::Json<Input>| async { input.0.foo }),
);
let addr = run_in_background(app).await;
@ -78,10 +102,10 @@ async fn consume_body_to_json_requires_json_content_type() {
foo: String,
}
let app = app()
.at("/")
.post(|_: Request<Body>, input: extract::Json<Input>| async { input.0.foo })
.into_service();
let app = route(
"/",
post(|_: Request<Body>, input: extract::Json<Input>| async { input.0.foo }),
);
let addr = run_in_background(app).await;
@ -110,14 +134,14 @@ async fn body_with_length_limit() {
const LIMIT: u64 = 8;
let app = app()
.at("/")
.post(
let app = route(
"/",
post(
|req: Request<Body>, _body: extract::BytesMaxLength<LIMIT>| async move {
dbg!(&req);
},
)
.into_service();
),
);
let addr = run_in_background(app).await;
@ -160,15 +184,16 @@ async fn body_with_length_limit() {
#[tokio::test]
async fn routing() {
let app = app()
.at("/users")
.get(|_: Request<Body>| async { "users#index" })
.post(|_: Request<Body>| async { "users#create" })
.at("/users/:id")
.get(|_: Request<Body>| async { "users#show" })
.at("/users/:id/action")
.get(|_: Request<Body>| async { "users#action" })
.into_service();
let app = route(
"/users",
get(|_: Request<Body>| async { "users#index" })
.post(|_: Request<Body>| async { "users#create" }),
)
.route("/users/:id", get(|_: Request<Body>| async { "users#show" }))
.route(
"/users/:id/action",
get(|_: Request<Body>| async { "users#action" }),
);
let addr = run_in_background(app).await;
@ -212,9 +237,9 @@ async fn routing() {
#[tokio::test]
async fn extracting_url_params() {
let app = app()
.at("/users/:id")
.get(
let app = route(
"/users/:id",
get(
|_: Request<Body>, params: extract::UrlParams<(i32,)>| async move {
let (id,) = params.0;
assert_eq!(id, 42);
@ -225,8 +250,8 @@ async fn extracting_url_params() {
assert_eq!(params_map.get("id").unwrap(), "1337");
assert_eq!(params_map.get_typed::<i32>("id").unwrap(), 1337);
},
)
.into_service();
),
);
let addr = run_in_background(app).await;
@ -249,12 +274,12 @@ async fn extracting_url_params() {
#[tokio::test]
async fn boxing() {
let app = app()
.at("/")
.get(|_: Request<Body>| async { "hi from GET" })
.boxed()
.post(|_: Request<Body>| async { "hi from POST" })
.into_service();
let app = route(
"/",
get(|_: Request<Body>| async { "hi from GET" })
.post(|_: Request<Body>| async { "hi from POST" }),
)
.boxed();
let addr = run_in_background(app).await;
@ -280,24 +305,24 @@ async fn service_handlers() {
use tower::service_fn;
use tower_http::services::ServeFile;
let app = app()
.at("/echo")
.post_service(service_fn(|req: Request<Body>| async move {
Ok::<_, Infallible>(Response::new(req.into_body()))
}))
// calling boxed isn't necessary here but done so
// we're sure it compiles
.boxed()
.at("/static/Cargo.toml")
.get_service(
let app = route(
"/echo",
on_method(
MethodFilter::Post,
service_fn(|req: Request<Body>| async move {
Ok::<_, Infallible>(Response::new(req.into_body()))
}),
),
)
.route(
"/static/Cargo.toml",
on_method(
MethodFilter::Get,
ServeFile::new("Cargo.toml").handle_error(|error: std::io::Error| {
(StatusCode::INTERNAL_SERVER_ERROR, error.to_string())
}),
)
// calling boxed isn't necessary here but done so
// we're sure it compiles
.boxed()
.into_service();
),
);
let addr = run_in_background(app).await;
@ -331,17 +356,15 @@ async fn middleware_on_single_route() {
"Hello, World!"
}
let app = app()
.at("/")
.get(
handle.layer(
ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(CompressionLayer::new())
.into_inner(),
),
)
.into_service();
let app = route(
"/",
get(handle.layer(
ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(CompressionLayer::new())
.into_inner(),
)),
);
let addr = run_in_background(app).await;
@ -360,14 +383,12 @@ async fn handling_errors_from_layered_single_routes() {
""
}
let app = app()
.at("/")
.get(
handle
.layer(TimeoutLayer::new(Duration::from_millis(100)))
.handle_error(|_error: BoxError| StatusCode::INTERNAL_SERVER_ERROR),
)
.into_service();
let app = route(
"/",
get(handle
.layer(TimeoutLayer::new(Duration::from_millis(100)))
.handle_error(|_error: BoxError| StatusCode::INTERNAL_SERVER_ERROR)),
);
let addr = run_in_background(app).await;
@ -377,19 +398,19 @@ async fn handling_errors_from_layered_single_routes() {
#[tokio::test]
async fn layer_on_whole_router() {
use tower::timeout::TimeoutLayer;
async fn handle(_req: Request<Body>) -> &'static str {
tokio::time::sleep(Duration::from_secs(10)).await;
""
}
let app = app()
.at("/")
.get(handle)
.layer(TimeoutLayer::new(Duration::from_millis(100)))
.handle_error(|_err: BoxError| StatusCode::INTERNAL_SERVER_ERROR)
.into_service();
let app = route("/", get(handle))
.layer(
ServiceBuilder::new()
.layer(CompressionLayer::new())
.timeout(Duration::from_millis(100))
.into_inner(),
)
.handle_error(|_err: BoxError| StatusCode::INTERNAL_SERVER_ERROR);
let addr = run_in_background(app).await;
@ -397,148 +418,150 @@ async fn layer_on_whole_router() {
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
// #[tokio::test]
// async fn nesting() {
// let api = app()
// .at("/users")
// .get(|_: Request<Body>| async { "users#index" })
// .post(|_: Request<Body>| async { "users#create" })
// .at("/users/:id")
// .get(
// |_: Request<Body>, params: extract::UrlParams<(i32,)>| async move {
// let (id,) = params.0;
// format!("users#show {}", id)
// },
// );
// TODO(david): layer that changes the response body type to have a different error
// let app = app()
// .at("/foo")
// .get(|_: Request<Body>| async { "foo" })
// .at("/api")
// .nest(api)
// .at("/bar")
// .get(|_: Request<Body>| async { "bar" })
// .into_service();
// // #[tokio::test]
// // async fn nesting() {
// // let api = app()
// // .at("/users")
// // .get(|_: Request<Body>| async { "users#index" })
// // .post(|_: Request<Body>| async { "users#create" })
// // .at("/users/:id")
// // .get(
// // |_: Request<Body>, params: extract::UrlParams<(i32,)>| async move {
// // let (id,) = params.0;
// // format!("users#show {}", id)
// // },
// // );
// let addr = run_in_background(app).await;
// // let app = app()
// // .at("/foo")
// // .get(|_: Request<Body>| async { "foo" })
// // .at("/api")
// // .nest(api)
// // .at("/bar")
// // .get(|_: Request<Body>| async { "bar" })
// // .into_service();
// let client = reqwest::Client::new();
// // let addr = run_in_background(app).await;
// let res = client
// .get(format!("http://{}/api/users", addr))
// .send()
// .await
// .unwrap();
// assert_eq!(res.text().await.unwrap(), "users#index");
// // let client = reqwest::Client::new();
// let res = client
// .post(format!("http://{}/api/users", addr))
// .send()
// .await
// .unwrap();
// assert_eq!(res.text().await.unwrap(), "users#create");
// // let res = client
// // .get(format!("http://{}/api/users", addr))
// // .send()
// // .await
// // .unwrap();
// // assert_eq!(res.text().await.unwrap(), "users#index");
// let res = client
// .get(format!("http://{}/api/users/42", addr))
// .send()
// .await
// .unwrap();
// assert_eq!(res.text().await.unwrap(), "users#show 42");
// // let res = client
// // .post(format!("http://{}/api/users", addr))
// // .send()
// // .await
// // .unwrap();
// // assert_eq!(res.text().await.unwrap(), "users#create");
// let res = client
// .get(format!("http://{}/foo", addr))
// .send()
// .await
// .unwrap();
// assert_eq!(res.text().await.unwrap(), "foo");
// // let res = client
// // .get(format!("http://{}/api/users/42", addr))
// // .send()
// // .await
// // .unwrap();
// // assert_eq!(res.text().await.unwrap(), "users#show 42");
// let res = client
// .get(format!("http://{}/bar", addr))
// .send()
// .await
// .unwrap();
// assert_eq!(res.text().await.unwrap(), "bar");
// }
// // let res = client
// // .get(format!("http://{}/foo", addr))
// // .send()
// // .await
// // .unwrap();
// // assert_eq!(res.text().await.unwrap(), "foo");
// #[tokio::test]
// async fn nesting_with_dynamic_part() {
// let api = app().at("/users/:id").get(
// |_: Request<Body>, params: extract::UrlParamsMap| async move {
// // let (version, id) = params.0;
// dbg!(&params);
// let version = params.get("version").unwrap();
// let id = params.get("id").unwrap();
// format!("users#show {} {}", version, id)
// },
// );
// // let res = client
// // .get(format!("http://{}/bar", addr))
// // .send()
// // .await
// // .unwrap();
// // assert_eq!(res.text().await.unwrap(), "bar");
// // }
// let app = app().at("/:version/api").nest(api).into_service();
// // #[tokio::test]
// // async fn nesting_with_dynamic_part() {
// // let api = app().at("/users/:id").get(
// // |_: Request<Body>, params: extract::UrlParamsMap| async move {
// // // let (version, id) = params.0;
// // dbg!(&params);
// // let version = params.get("version").unwrap();
// // let id = params.get("id").unwrap();
// // format!("users#show {} {}", version, id)
// // },
// // );
// let addr = run_in_background(app).await;
// // let app = app().at("/:version/api").nest(api).into_service();
// let client = reqwest::Client::new();
// // let addr = run_in_background(app).await;
// let res = client
// .get(format!("http://{}/v0/api/users/123", addr))
// .send()
// .await
// .unwrap();
// let status = res.status();
// assert_eq!(res.text().await.unwrap(), "users#show v0 123");
// assert_eq!(status, StatusCode::OK);
// }
// // let client = reqwest::Client::new();
// #[tokio::test]
// async fn nesting_more_deeply() {
// let users_api = app()
// .at("/:id")
// .get(|req: Request<Body>| async move {
// dbg!(&req.uri().path());
// "users#show"
// });
// // let res = client
// // .get(format!("http://{}/v0/api/users/123", addr))
// // .send()
// // .await
// // .unwrap();
// // let status = res.status();
// // assert_eq!(res.text().await.unwrap(), "users#show v0 123");
// // assert_eq!(status, StatusCode::OK);
// // }
// let games_api = app()
// .at("/")
// .post(|req: Request<Body>| async move {
// dbg!(&req.uri().path());
// "games#create"
// });
// // #[tokio::test]
// // async fn nesting_more_deeply() {
// // let users_api = app()
// // .at("/:id")
// // .get(|req: Request<Body>| async move {
// // dbg!(&req.uri().path());
// // "users#show"
// // });
// let api = app()
// .at("/users")
// .nest(users_api)
// .at("/games")
// .nest(games_api);
// // let games_api = app()
// // .at("/")
// // .post(|req: Request<Body>| async move {
// // dbg!(&req.uri().path());
// // "games#create"
// // });
// let app = app().at("/:version/api").nest(api).into_service();
// // let api = app()
// // .at("/users")
// // .nest(users_api)
// // .at("/games")
// // .nest(games_api);
// let addr = run_in_background(app).await;
// // let app = app().at("/:version/api").nest(api).into_service();
// let client = reqwest::Client::new();
// // let addr = run_in_background(app).await;
// // let res = client
// // .get(format!("http://{}/v0/api/users/123", addr))
// // .send()
// // .await
// // .unwrap();
// // assert_eq!(res.status(), StatusCode::OK);
// // let client = reqwest::Client::new();
// println!("============================");
// // // let res = client
// // // .get(format!("http://{}/v0/api/users/123", addr))
// // // .send()
// // // .await
// // // .unwrap();
// // // assert_eq!(res.status(), StatusCode::OK);
// let res = client
// .post(format!("http://{}/v0/api/games", addr))
// .send()
// .await
// .unwrap();
// assert_eq!(res.status(), StatusCode::OK);
// }
// // println!("============================");
// TODO(david): nesting more deeply
// // let res = client
// // .post(format!("http://{}/v0/api/games", addr))
// // .send()
// // .await
// // .unwrap();
// // assert_eq!(res.status(), StatusCode::OK);
// // }
// TODO(david): composing two apps
// TODO(david): composing two apps with one at a "sub path"
// TODO(david): composing two boxed apps
// TODO(david): composing two apps that have had layers applied
// // TODO(david): nesting more deeply
// // TODO(david): composing two apps
// // TODO(david): composing two apps with one at a "sub path"
// // TODO(david): composing two boxed apps
// // TODO(david): composing two apps that have had layers applied
/// Run a `tower::Service` in the background and get a URI for it.
pub async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
@ -547,8 +570,8 @@ where
ResBody: http_body::Body + Send + 'static,
ResBody::Data: Send,
ResBody::Error: Into<BoxError>,
S::Error: Into<BoxError>,
S::Future: Send,
S::Error: Into<BoxError>,
{
let listener = TcpListener::bind("127.0.0.1:0").expect("Could not bind ephemeral socket");
let addr = listener.local_addr().unwrap();