diff --git a/examples/rest-grpc-multiplex/src/multiplex_service.rs b/examples/rest-grpc-multiplex/src/multiplex_service.rs
index 50c83a9e..de6b8b6c 100644
--- a/examples/rest-grpc-multiplex/src/multiplex_service.rs
+++ b/examples/rest-grpc-multiplex/src/multiplex_service.rs
@@ -1,4 +1,4 @@
-use axum::{body::BoxBody, response::IntoResponse};
+use axum::{body::BoxBody, http::header::CONTENT_TYPE, response::IntoResponse};
use futures::{future::BoxFuture, ready};
use hyper::{Body, Request, Response};
use std::{
@@ -7,7 +7,6 @@ use std::{
};
use tower::Service;
-#[derive(Clone)]
pub struct MultiplexService {
rest: A,
rest_ready: bool,
@@ -26,6 +25,22 @@ impl MultiplexService {
}
}
+impl Clone for MultiplexService
+where
+ A: Clone,
+ B: Clone,
+{
+ fn clone(&self) -> Self {
+ Self {
+ rest: self.rest.clone(),
+ grpc: self.grpc.clone(),
+ // the cloned services probably wont be ready
+ rest_ready: false,
+ grpc_ready: false,
+ }
+ }
+}
+
impl Service> for MultiplexService
where
A: Service, Error = Infallible>,
@@ -40,6 +55,7 @@ where
type Future = BoxFuture<'static, Result>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
+ // drive readiness for each inner service and record which is ready
loop {
match (self.rest_ready, self.grpc_ready) {
(true, true) => {
@@ -47,7 +63,7 @@ where
}
(false, _) => {
ready!(self.rest.poll_ready(cx))?;
- self.rest_ready = false;
+ self.rest_ready = true;
}
(_, false) => {
ready!(self.grpc.poll_ready(cx))?;
@@ -58,6 +74,19 @@ where
}
fn call(&mut self, req: Request) -> Self::Future {
+ // require users to call `poll_ready` first, if they don't we're allowed to panic
+ // as per the `tower::Service` contract
+ assert!(
+ self.grpc_ready,
+ "grpc service not ready. Did you forget to call `poll_ready`?"
+ );
+ assert!(
+ self.rest_ready,
+ "rest service not ready. Did you forget to call `poll_ready`?"
+ );
+
+ // if we get a grpc request call the grpc service, otherwise call the rest service
+ // when calling a service it becomes not-ready so we have drive readiness again
if is_grpc_request(&req) {
self.grpc_ready = false;
let future = self.grpc.call(req);
@@ -78,7 +107,7 @@ where
fn is_grpc_request(req: &Request) -> bool {
req.headers()
- .get("content-type")
+ .get(CONTENT_TYPE)
.map(|content_type| content_type.as_bytes())
.filter(|content_type| content_type.starts_with(b"application/grpc"))
.is_some()