mirror of
https://github.com/tokio-rs/axum.git
synced 2024-10-23 17:36:39 +02:00
Support any type of response body
This commit is contained in:
parent
5efa7ab7ea
commit
a04c98dd42
2 changed files with 68 additions and 9 deletions
|
@ -21,3 +21,6 @@ tower = { version = "0.4", features = ["util"] }
|
|||
[dev-dependencies]
|
||||
tokio = { version = "1.6.1", features = ["macros", "rt"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tower = { version = "0.4", features = ["util", "make"] }
|
||||
tower-http = { version = "0.1", features = ["trace"] }
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
|
|
74
src/lib.rs
74
src/lib.rs
|
@ -31,6 +31,7 @@ use async_trait::async_trait;
|
|||
use bytes::Bytes;
|
||||
use futures_util::{future, ready};
|
||||
use http::{Method, Request, Response, StatusCode};
|
||||
use http_body::{combinators::BoxBody, Body as _};
|
||||
use pin_project::pin_project;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use std::{
|
||||
|
@ -39,7 +40,7 @@ use std::{
|
|||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tower::{Service, ServiceExt};
|
||||
use tower::{BoxError, Service, ServiceExt};
|
||||
|
||||
pub use hyper::body::Body;
|
||||
|
||||
|
@ -155,6 +156,9 @@ pub enum Error {
|
|||
|
||||
#[error("failed to deserialize query string")]
|
||||
DeserializeQueryString(#[from] serde_urlencoded::de::Error),
|
||||
|
||||
#[error("failed generating the response body")]
|
||||
ResponseBody(#[source] BoxError),
|
||||
}
|
||||
|
||||
// TODO(david): make this trait sealed
|
||||
|
@ -408,14 +412,19 @@ impl RouteSpec {
|
|||
}
|
||||
}
|
||||
|
||||
impl<H, F> Service<Request<Body>> for Route<H, F>
|
||||
impl<H, F, HB, FB> Service<Request<Body>> for Route<H, F>
|
||||
where
|
||||
H: Service<Request<Body>, Response = Response<Body>, Error = Error>,
|
||||
F: Service<Request<Body>, Response = Response<Body>, Error = Error>,
|
||||
H: Service<Request<Body>, Response = Response<HB>, Error = Error>,
|
||||
F: Service<Request<Body>, Response = Response<FB>, Error = Error>,
|
||||
HB: http_body::Body + Send + Sync + 'static,
|
||||
HB::Error: Into<BoxError>,
|
||||
FB: http_body::Body<Data = HB::Data> + Send + Sync + 'static,
|
||||
FB::Error: Into<BoxError>,
|
||||
{
|
||||
type Response = Response<Body>;
|
||||
type Response = Response<BoxBody<HB::Data, Error>>;
|
||||
type Error = Error;
|
||||
type Future = future::Either<H::Future, F::Future>;
|
||||
// type Future = future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
|
||||
type Future = future::Either<BoxResponseBody<H::Future>, BoxResponseBody<F::Future>>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
loop {
|
||||
|
@ -442,18 +451,38 @@ where
|
|||
"handler not ready. Did you forget to call `poll_ready`?"
|
||||
);
|
||||
self.handler_ready = false;
|
||||
future::Either::Left(self.handler.call(req))
|
||||
future::Either::Left(BoxResponseBody(self.handler.call(req)))
|
||||
} else {
|
||||
assert!(
|
||||
self.fallback_ready,
|
||||
"fallback not ready. Did you forget to call `poll_ready`?"
|
||||
);
|
||||
self.fallback_ready = false;
|
||||
future::Either::Right(self.fallback.call(req))
|
||||
// TODO(david): this leads to each route creating one box body, probably not great
|
||||
future::Either::Right(BoxResponseBody(self.fallback.call(req)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
pub struct BoxResponseBody<F>(#[pin] F);
|
||||
|
||||
impl<F, B> Future for BoxResponseBody<F>
|
||||
where
|
||||
F: Future<Output = Result<Response<B>, Error>>,
|
||||
B: http_body::Body + Send + Sync + 'static,
|
||||
B::Error: Into<BoxError>,
|
||||
{
|
||||
type Output = Result<Response<BoxBody<B::Data, Error>>, Error>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let response: Response<B> = ready!(self.project().0.poll(cx))?;
|
||||
let response =
|
||||
response.map(|body| body.map_err(|err| Error::ResponseBody(err.into())).boxed());
|
||||
Poll::Ready(Ok(response))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, T> Service<T> for App<R>
|
||||
where
|
||||
R: Service<T>,
|
||||
|
@ -496,6 +525,10 @@ where
|
|||
mod tests {
|
||||
#![allow(warnings)]
|
||||
use super::*;
|
||||
use hyper::Server;
|
||||
use std::{fmt, net::SocketAddr};
|
||||
use tower::{make::Shared, ServiceBuilder};
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
#[tokio::test]
|
||||
async fn basic() {
|
||||
|
@ -517,7 +550,30 @@ mod tests {
|
|||
dbg!(&body);
|
||||
}
|
||||
|
||||
async fn body_to_string(res: Response<Body>) -> String {
|
||||
#[allow(dead_code)]
|
||||
// this should just compile
|
||||
async fn compatible_with_hyper_and_tower_http() {
|
||||
let app = app()
|
||||
.at("/")
|
||||
.get(root)
|
||||
.at("/users")
|
||||
.get(users_index)
|
||||
.post(users_create);
|
||||
|
||||
let app = ServiceBuilder::new()
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.service(app);
|
||||
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
let server = Server::bind(&addr).serve(Shared::new(app));
|
||||
server.await.unwrap();
|
||||
}
|
||||
|
||||
async fn body_to_string<B>(res: Response<B>) -> String
|
||||
where
|
||||
B: http_body::Body,
|
||||
B::Error: fmt::Debug,
|
||||
{
|
||||
let bytes = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||
String::from_utf8(bytes.to_vec()).unwrap()
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue