mirror of
https://github.com/tokio-rs/axum.git
synced 2025-04-26 13:56:22 +02:00
Support routing to tower services
This commit is contained in:
parent
433128102b
commit
f983b37fea
2 changed files with 124 additions and 16 deletions
|
@ -21,6 +21,6 @@ tower = { version = "0.4", features = ["util"] }
|
|||
[dev-dependencies]
|
||||
tokio = { version = "1.6.1", features = ["macros", "rt"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tower = { version = "0.4", features = ["util", "make"] }
|
||||
tower = { version = "0.4", features = ["util", "make", "timeout"] }
|
||||
tower-http = { version = "0.1", features = ["trace"] }
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
|
|
138
src/lib.rs
138
src/lib.rs
|
@ -41,7 +41,7 @@ use std::{
|
|||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tower::{BoxError, Service, ServiceExt};
|
||||
use tower::{BoxError, Layer, Service, ServiceExt};
|
||||
|
||||
pub use hyper::body::Body;
|
||||
|
||||
|
@ -83,6 +83,14 @@ impl<R> RouteAt<R> {
|
|||
self.add_route(handler_fn, Method::GET)
|
||||
}
|
||||
|
||||
pub fn get_service<S, B>(self, service: S) -> RouteBuilder<Route<S, R>>
|
||||
where
|
||||
S: Service<Request<Body>, Response = Response<B>> + Clone,
|
||||
S::Error: Into<BoxError>,
|
||||
{
|
||||
self.add_route_service(service, Method::GET)
|
||||
}
|
||||
|
||||
pub fn post<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
|
||||
where
|
||||
F: Handler<T>,
|
||||
|
@ -90,16 +98,25 @@ impl<R> RouteAt<R> {
|
|||
self.add_route(handler_fn, Method::POST)
|
||||
}
|
||||
|
||||
pub fn post_service<S, B>(self, service: S) -> RouteBuilder<Route<S, R>>
|
||||
where
|
||||
S: Service<Request<Body>, Response = Response<B>> + Clone,
|
||||
S::Error: Into<BoxError>,
|
||||
{
|
||||
self.add_route_service(service, Method::POST)
|
||||
}
|
||||
|
||||
fn add_route<H, T>(self, handler: H, method: Method) -> RouteBuilder<Route<HandlerSvc<H, T>, R>>
|
||||
where
|
||||
H: Handler<T>,
|
||||
{
|
||||
self.add_route_service(HandlerSvc::new(handler), method)
|
||||
}
|
||||
|
||||
fn add_route_service<S>(self, service: S, method: Method) -> RouteBuilder<Route<S, R>> {
|
||||
let new_app = App {
|
||||
router: Route {
|
||||
handler: HandlerSvc {
|
||||
handler,
|
||||
_input: PhantomData,
|
||||
},
|
||||
service,
|
||||
route_spec: RouteSpec {
|
||||
method,
|
||||
spec: self.route_spec.clone(),
|
||||
|
@ -135,12 +152,28 @@ impl<R> RouteBuilder<R> {
|
|||
self.app.at_bytes(self.route_spec).get(handler_fn)
|
||||
}
|
||||
|
||||
pub fn get_service<S, B>(self, service: S) -> RouteBuilder<Route<S, R>>
|
||||
where
|
||||
S: Service<Request<Body>, Response = Response<B>> + Clone,
|
||||
S::Error: Into<BoxError>,
|
||||
{
|
||||
self.app.at_bytes(self.route_spec).get_service(service)
|
||||
}
|
||||
|
||||
pub fn post<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
|
||||
where
|
||||
F: Handler<T>,
|
||||
{
|
||||
self.app.at_bytes(self.route_spec).post(handler_fn)
|
||||
}
|
||||
|
||||
pub fn post_service<S, B>(self, service: S) -> RouteBuilder<Route<S, R>>
|
||||
where
|
||||
S: Service<Request<Body>, Response = Response<B>> + Clone,
|
||||
S::Error: Into<BoxError>,
|
||||
{
|
||||
self.app.at_bytes(self.route_spec).post_service(service)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
@ -160,6 +193,18 @@ pub enum Error {
|
|||
|
||||
#[error("failed generating the response body")]
|
||||
ResponseBody(#[source] BoxError),
|
||||
|
||||
#[error("handler service returned an error")]
|
||||
Service(#[source] BoxError),
|
||||
}
|
||||
|
||||
impl From<BoxError> for Error {
|
||||
fn from(err: BoxError) -> Self {
|
||||
match err.downcast::<Error>() {
|
||||
Ok(err) => *err,
|
||||
Err(err) => Error::Service(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Infallible> for Error {
|
||||
|
@ -170,10 +215,17 @@ impl From<Infallible> for Error {
|
|||
|
||||
// TODO(david): make this trait sealed
|
||||
#[async_trait]
|
||||
pub trait Handler<Out> {
|
||||
pub trait Handler<In>: Sized {
|
||||
type ResponseBody;
|
||||
|
||||
async fn call(self, req: Request<Body>) -> Result<Response<Self::ResponseBody>, Error>;
|
||||
|
||||
fn layer<L>(self, layer: L) -> Layered<L::Service, In>
|
||||
where
|
||||
L: Layer<HandlerSvc<Self, In>>,
|
||||
{
|
||||
Layered::new(layer.layer(HandlerSvc::new(self)))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
@ -237,11 +289,60 @@ macro_rules! impl_handler {
|
|||
|
||||
impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
|
||||
|
||||
pub struct Layered<S, T> {
|
||||
svc: S,
|
||||
_input: PhantomData<fn() -> T>,
|
||||
}
|
||||
|
||||
impl<S, T> Clone for Layered<S, T>
|
||||
where
|
||||
S: Clone,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self::new(self.svc.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S, T, B> Handler<T> for Layered<S, T>
|
||||
where
|
||||
S: Service<Request<Body>, Response = Response<B>> + Send,
|
||||
S::Error: Into<BoxError>,
|
||||
S::Future: Send,
|
||||
{
|
||||
type ResponseBody = B;
|
||||
|
||||
async fn call(self, req: Request<Body>) -> Result<Response<Self::ResponseBody>, Error> {
|
||||
self.svc
|
||||
.oneshot(req)
|
||||
.await
|
||||
.map_err(|err| Error::from(err.into()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T> Layered<S, T> {
|
||||
fn new(svc: S) -> Self {
|
||||
Self {
|
||||
svc,
|
||||
_input: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct HandlerSvc<H, T> {
|
||||
handler: H,
|
||||
_input: PhantomData<fn() -> T>,
|
||||
}
|
||||
|
||||
impl<H, T> HandlerSvc<H, T> {
|
||||
fn new(handler: H) -> Self {
|
||||
Self {
|
||||
handler,
|
||||
_input: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<H, T> Clone for HandlerSvc<H, T>
|
||||
where
|
||||
H: Clone,
|
||||
|
@ -382,7 +483,7 @@ impl<R> Service<R> for EmptyRouter {
|
|||
}
|
||||
|
||||
pub struct Route<H, F> {
|
||||
handler: H,
|
||||
service: H,
|
||||
route_spec: RouteSpec,
|
||||
fallback: F,
|
||||
handler_ready: bool,
|
||||
|
@ -396,7 +497,7 @@ where
|
|||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
handler: self.handler.clone(),
|
||||
service: self.service.clone(),
|
||||
fallback: self.fallback.clone(),
|
||||
route_spec: self.route_spec.clone(),
|
||||
// important to reset readiness when cloning
|
||||
|
@ -438,7 +539,7 @@ where
|
|||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
loop {
|
||||
if !self.handler_ready {
|
||||
ready!(self.handler.poll_ready(cx)).map_err(Into::into)?;
|
||||
ready!(self.service.poll_ready(cx)).map_err(Into::into)?;
|
||||
self.handler_ready = true;
|
||||
}
|
||||
|
||||
|
@ -460,7 +561,7 @@ where
|
|||
"handler not ready. Did you forget to call `poll_ready`?"
|
||||
);
|
||||
self.handler_ready = false;
|
||||
future::Either::Left(BoxResponseBody(self.handler.call(req)))
|
||||
future::Either::Left(BoxResponseBody(self.service.call(req)))
|
||||
} else {
|
||||
assert!(
|
||||
self.fallback_ready,
|
||||
|
@ -542,8 +643,11 @@ mod tests {
|
|||
#![allow(warnings)]
|
||||
use super::*;
|
||||
use hyper::Server;
|
||||
use std::time::Duration;
|
||||
use std::{fmt, net::SocketAddr};
|
||||
use tower::{make::Shared, ServiceBuilder};
|
||||
use tower::{
|
||||
layer::util::Identity, make::Shared, service_fn, timeout::TimeoutLayer, ServiceBuilder,
|
||||
};
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -559,11 +663,13 @@ mod tests {
|
|||
username: String,
|
||||
}
|
||||
|
||||
async fn root(_: Request<Body>) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::from("Hello, World!")))
|
||||
}
|
||||
|
||||
let mut app = app()
|
||||
.at("/")
|
||||
.get(|_: Request<Body>| async {
|
||||
Ok::<_, Error>(Response::new(Body::from("Hello, World!")))
|
||||
})
|
||||
.get(root.layer(TimeoutLayer::new(Duration::from_secs(30))))
|
||||
.at("/users")
|
||||
.get(|_: Request<Body>, pagination: Query<Pagination>| async {
|
||||
let pagination = pagination.into_inner();
|
||||
|
@ -577,7 +683,9 @@ mod tests {
|
|||
assert_eq!(payload.username, "bob");
|
||||
|
||||
Ok::<_, Error>(Response::new(Body::from("users#create")))
|
||||
});
|
||||
})
|
||||
.at("/service")
|
||||
.get_service(service_fn(root));
|
||||
|
||||
let res = app
|
||||
.ready()
|
||||
|
|
Loading…
Add table
Reference in a new issue