1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-04-26 13:56:22 +02:00

Support routing to tower services

This commit is contained in:
David Pedersen 2021-05-30 03:10:55 +02:00
parent 433128102b
commit f983b37fea
2 changed files with 124 additions and 16 deletions

View file

@ -21,6 +21,6 @@ tower = { version = "0.4", features = ["util"] }
[dev-dependencies]
tokio = { version = "1.6.1", features = ["macros", "rt"] }
serde = { version = "1.0", features = ["derive"] }
tower = { version = "0.4", features = ["util", "make"] }
tower = { version = "0.4", features = ["util", "make", "timeout"] }
tower-http = { version = "0.1", features = ["trace"] }
hyper = { version = "0.14", features = ["full"] }

View file

@ -41,7 +41,7 @@ use std::{
pin::Pin,
task::{Context, Poll},
};
use tower::{BoxError, Service, ServiceExt};
use tower::{BoxError, Layer, Service, ServiceExt};
pub use hyper::body::Body;
@ -83,6 +83,14 @@ impl<R> RouteAt<R> {
self.add_route(handler_fn, Method::GET)
}
pub fn get_service<S, B>(self, service: S) -> RouteBuilder<Route<S, R>>
where
S: Service<Request<Body>, Response = Response<B>> + Clone,
S::Error: Into<BoxError>,
{
self.add_route_service(service, Method::GET)
}
pub fn post<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
where
F: Handler<T>,
@ -90,16 +98,25 @@ impl<R> RouteAt<R> {
self.add_route(handler_fn, Method::POST)
}
pub fn post_service<S, B>(self, service: S) -> RouteBuilder<Route<S, R>>
where
S: Service<Request<Body>, Response = Response<B>> + Clone,
S::Error: Into<BoxError>,
{
self.add_route_service(service, Method::POST)
}
fn add_route<H, T>(self, handler: H, method: Method) -> RouteBuilder<Route<HandlerSvc<H, T>, R>>
where
H: Handler<T>,
{
self.add_route_service(HandlerSvc::new(handler), method)
}
fn add_route_service<S>(self, service: S, method: Method) -> RouteBuilder<Route<S, R>> {
let new_app = App {
router: Route {
handler: HandlerSvc {
handler,
_input: PhantomData,
},
service,
route_spec: RouteSpec {
method,
spec: self.route_spec.clone(),
@ -135,12 +152,28 @@ impl<R> RouteBuilder<R> {
self.app.at_bytes(self.route_spec).get(handler_fn)
}
pub fn get_service<S, B>(self, service: S) -> RouteBuilder<Route<S, R>>
where
S: Service<Request<Body>, Response = Response<B>> + Clone,
S::Error: Into<BoxError>,
{
self.app.at_bytes(self.route_spec).get_service(service)
}
pub fn post<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
where
F: Handler<T>,
{
self.app.at_bytes(self.route_spec).post(handler_fn)
}
pub fn post_service<S, B>(self, service: S) -> RouteBuilder<Route<S, R>>
where
S: Service<Request<Body>, Response = Response<B>> + Clone,
S::Error: Into<BoxError>,
{
self.app.at_bytes(self.route_spec).post_service(service)
}
}
#[derive(Debug, thiserror::Error)]
@ -160,6 +193,18 @@ pub enum Error {
#[error("failed generating the response body")]
ResponseBody(#[source] BoxError),
#[error("handler service returned an error")]
Service(#[source] BoxError),
}
impl From<BoxError> for Error {
fn from(err: BoxError) -> Self {
match err.downcast::<Error>() {
Ok(err) => *err,
Err(err) => Error::Service(err),
}
}
}
impl From<Infallible> for Error {
@ -170,10 +215,17 @@ impl From<Infallible> for Error {
// TODO(david): make this trait sealed
#[async_trait]
pub trait Handler<Out> {
pub trait Handler<In>: Sized {
type ResponseBody;
async fn call(self, req: Request<Body>) -> Result<Response<Self::ResponseBody>, Error>;
fn layer<L>(self, layer: L) -> Layered<L::Service, In>
where
L: Layer<HandlerSvc<Self, In>>,
{
Layered::new(layer.layer(HandlerSvc::new(self)))
}
}
#[async_trait]
@ -237,11 +289,60 @@ macro_rules! impl_handler {
impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
pub struct Layered<S, T> {
svc: S,
_input: PhantomData<fn() -> T>,
}
impl<S, T> Clone for Layered<S, T>
where
S: Clone,
{
fn clone(&self) -> Self {
Self::new(self.svc.clone())
}
}
#[async_trait]
impl<S, T, B> Handler<T> for Layered<S, T>
where
S: Service<Request<Body>, Response = Response<B>> + Send,
S::Error: Into<BoxError>,
S::Future: Send,
{
type ResponseBody = B;
async fn call(self, req: Request<Body>) -> Result<Response<Self::ResponseBody>, Error> {
self.svc
.oneshot(req)
.await
.map_err(|err| Error::from(err.into()))
}
}
impl<S, T> Layered<S, T> {
fn new(svc: S) -> Self {
Self {
svc,
_input: PhantomData,
}
}
}
pub struct HandlerSvc<H, T> {
handler: H,
_input: PhantomData<fn() -> T>,
}
impl<H, T> HandlerSvc<H, T> {
fn new(handler: H) -> Self {
Self {
handler,
_input: PhantomData,
}
}
}
impl<H, T> Clone for HandlerSvc<H, T>
where
H: Clone,
@ -382,7 +483,7 @@ impl<R> Service<R> for EmptyRouter {
}
pub struct Route<H, F> {
handler: H,
service: H,
route_spec: RouteSpec,
fallback: F,
handler_ready: bool,
@ -396,7 +497,7 @@ where
{
fn clone(&self) -> Self {
Self {
handler: self.handler.clone(),
service: self.service.clone(),
fallback: self.fallback.clone(),
route_spec: self.route_spec.clone(),
// important to reset readiness when cloning
@ -438,7 +539,7 @@ where
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
loop {
if !self.handler_ready {
ready!(self.handler.poll_ready(cx)).map_err(Into::into)?;
ready!(self.service.poll_ready(cx)).map_err(Into::into)?;
self.handler_ready = true;
}
@ -460,7 +561,7 @@ where
"handler not ready. Did you forget to call `poll_ready`?"
);
self.handler_ready = false;
future::Either::Left(BoxResponseBody(self.handler.call(req)))
future::Either::Left(BoxResponseBody(self.service.call(req)))
} else {
assert!(
self.fallback_ready,
@ -542,8 +643,11 @@ mod tests {
#![allow(warnings)]
use super::*;
use hyper::Server;
use std::time::Duration;
use std::{fmt, net::SocketAddr};
use tower::{make::Shared, ServiceBuilder};
use tower::{
layer::util::Identity, make::Shared, service_fn, timeout::TimeoutLayer, ServiceBuilder,
};
use tower_http::trace::TraceLayer;
#[tokio::test]
@ -559,11 +663,13 @@ mod tests {
username: String,
}
async fn root(_: Request<Body>) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from("Hello, World!")))
}
let mut app = app()
.at("/")
.get(|_: Request<Body>| async {
Ok::<_, Error>(Response::new(Body::from("Hello, World!")))
})
.get(root.layer(TimeoutLayer::new(Duration::from_secs(30))))
.at("/users")
.get(|_: Request<Body>, pagination: Query<Pagination>| async {
let pagination = pagination.into_inner();
@ -577,7 +683,9 @@ mod tests {
assert_eq!(payload.username, "bob");
Ok::<_, Error>(Response::new(Body::from("users#create")))
});
})
.at("/service")
.get_service(service_fn(root));
let res = app
.ready()