diff --git a/examples/rest-grpc-multiplex/src/multiplex_service.rs b/examples/rest-grpc-multiplex/src/multiplex_service.rs index 50c83a9e..de6b8b6c 100644 --- a/examples/rest-grpc-multiplex/src/multiplex_service.rs +++ b/examples/rest-grpc-multiplex/src/multiplex_service.rs @@ -1,4 +1,4 @@ -use axum::{body::BoxBody, response::IntoResponse}; +use axum::{body::BoxBody, http::header::CONTENT_TYPE, response::IntoResponse}; use futures::{future::BoxFuture, ready}; use hyper::{Body, Request, Response}; use std::{ @@ -7,7 +7,6 @@ use std::{ }; use tower::Service; -#[derive(Clone)] pub struct MultiplexService { rest: A, rest_ready: bool, @@ -26,6 +25,22 @@ impl MultiplexService { } } +impl Clone for MultiplexService +where + A: Clone, + B: Clone, +{ + fn clone(&self) -> Self { + Self { + rest: self.rest.clone(), + grpc: self.grpc.clone(), + // the cloned services probably wont be ready + rest_ready: false, + grpc_ready: false, + } + } +} + impl Service> for MultiplexService where A: Service, Error = Infallible>, @@ -40,6 +55,7 @@ where type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + // drive readiness for each inner service and record which is ready loop { match (self.rest_ready, self.grpc_ready) { (true, true) => { @@ -47,7 +63,7 @@ where } (false, _) => { ready!(self.rest.poll_ready(cx))?; - self.rest_ready = false; + self.rest_ready = true; } (_, false) => { ready!(self.grpc.poll_ready(cx))?; @@ -58,6 +74,19 @@ where } fn call(&mut self, req: Request) -> Self::Future { + // require users to call `poll_ready` first, if they don't we're allowed to panic + // as per the `tower::Service` contract + assert!( + self.grpc_ready, + "grpc service not ready. Did you forget to call `poll_ready`?" + ); + assert!( + self.rest_ready, + "rest service not ready. Did you forget to call `poll_ready`?" + ); + + // if we get a grpc request call the grpc service, otherwise call the rest service + // when calling a service it becomes not-ready so we have drive readiness again if is_grpc_request(&req) { self.grpc_ready = false; let future = self.grpc.call(req); @@ -78,7 +107,7 @@ where fn is_grpc_request(req: &Request) -> bool { req.headers() - .get("content-type") + .get(CONTENT_TYPE) .map(|content_type| content_type.as_bytes()) .filter(|content_type| content_type.starts_with(b"application/grpc")) .is_some()