mirror of
https://github.com/tokio-rs/axum.git
synced 2025-03-31 11:49:55 +02:00
Fix layer support
This commit is contained in:
parent
f983b37fea
commit
b763eaa037
3 changed files with 244 additions and 66 deletions
|
@ -22,5 +22,5 @@ tower = { version = "0.4", features = ["util"] }
|
|||
tokio = { version = "1.6.1", features = ["macros", "rt"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tower = { version = "0.4", features = ["util", "make", "timeout"] }
|
||||
tower-http = { version = "0.1", features = ["trace"] }
|
||||
tower-http = { version = "0.1", features = ["trace", "compression"] }
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
|
|
71
src/body.rs
Normal file
71
src/body.rs
Normal file
|
@ -0,0 +1,71 @@
|
|||
use bytes::Buf;
|
||||
use http_body::{Body, Empty};
|
||||
use std::{
|
||||
fmt,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
/// A boxed [`Body`] trait object.
|
||||
pub struct BoxBody<D, E> {
|
||||
inner: Pin<Box<dyn Body<Data = D, Error = E> + Send + Sync + 'static>>,
|
||||
}
|
||||
|
||||
impl<D, E> BoxBody<D, E> {
|
||||
/// Create a new `BoxBody`.
|
||||
pub fn new<B>(body: B) -> Self
|
||||
where
|
||||
B: Body<Data = D, Error = E> + Send + Sync + 'static,
|
||||
D: Buf,
|
||||
{
|
||||
Self {
|
||||
inner: Box::pin(body),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(david): upstream this to http-body?
|
||||
impl<D, E> Default for BoxBody<D, E>
|
||||
where
|
||||
D: bytes::Buf + 'static,
|
||||
{
|
||||
fn default() -> Self {
|
||||
BoxBody::new(Empty::<D>::new().map_err(|err| match err {}))
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, E> fmt::Debug for BoxBody<D, E> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("BoxBody").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, E> Body for BoxBody<D, E>
|
||||
where
|
||||
D: Buf,
|
||||
{
|
||||
type Data = D;
|
||||
type Error = E;
|
||||
|
||||
fn poll_data(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
|
||||
self.inner.as_mut().poll_data(cx)
|
||||
}
|
||||
|
||||
fn poll_trailers(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
|
||||
self.inner.as_mut().poll_trailers(cx)
|
||||
}
|
||||
|
||||
fn is_end_stream(&self) -> bool {
|
||||
self.inner.is_end_stream()
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> http_body::SizeHint {
|
||||
self.inner.size_hint()
|
||||
}
|
||||
}
|
237
src/lib.rs
237
src/lib.rs
|
@ -7,18 +7,6 @@ Improvements to make:
|
|||
Somehow return generic "into response" kinda types without having to manually
|
||||
create hyper::Body for everything
|
||||
|
||||
Don't make Query and Json contain a Result, instead make generic wrapper
|
||||
for "optional" inputs
|
||||
|
||||
Make it possible to convert QueryError and JsonError into responses
|
||||
|
||||
Support wrapping single routes in tower::Layer
|
||||
|
||||
Support putting a tower::Service at a Route
|
||||
|
||||
Don't require the response body to be hyper::Body, wont work if we're wrapping
|
||||
single routes in layers
|
||||
|
||||
Support extracting headers, perhaps via `headers::Header`?
|
||||
|
||||
Implement `FromRequest` for more functions, with macro
|
||||
|
@ -31,7 +19,7 @@ use async_trait::async_trait;
|
|||
use bytes::Bytes;
|
||||
use futures_util::{future, ready};
|
||||
use http::{Method, Request, Response, StatusCode};
|
||||
use http_body::{combinators::BoxBody, Body as _};
|
||||
use http_body::Body as _;
|
||||
use pin_project::pin_project;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use std::{
|
||||
|
@ -43,6 +31,9 @@ use std::{
|
|||
};
|
||||
use tower::{BoxError, Layer, Service, ServiceExt};
|
||||
|
||||
mod body;
|
||||
pub use body::BoxBody;
|
||||
|
||||
pub use hyper::body::Body;
|
||||
|
||||
pub fn app() -> App<EmptyRouter> {
|
||||
|
@ -134,12 +125,23 @@ impl<R> RouteAt<R> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RouteBuilder<R> {
|
||||
app: App<R>,
|
||||
route_spec: Bytes,
|
||||
}
|
||||
|
||||
impl<R> Clone for RouteBuilder<R>
|
||||
where
|
||||
R: Clone,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
app: self.app.clone(),
|
||||
route_spec: self.route_spec.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> RouteBuilder<R> {
|
||||
pub fn at(self, route_spec: &str) -> RouteAt<R> {
|
||||
self.app.at(route_spec)
|
||||
|
@ -174,6 +176,13 @@ impl<R> RouteBuilder<R> {
|
|||
{
|
||||
self.app.at_bytes(self.route_spec).post_service(service)
|
||||
}
|
||||
|
||||
pub fn into_service(self) -> IntoService<R> {
|
||||
IntoService {
|
||||
app: self.app,
|
||||
poll_ready_error: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
@ -183,13 +192,13 @@ pub enum Error {
|
|||
DeserializeRequestBody(#[source] serde_json::Error),
|
||||
|
||||
#[error("failed to consume the body")]
|
||||
ConsumeBody(#[source] hyper::Error),
|
||||
ConsumeRequestBody(#[source] hyper::Error),
|
||||
|
||||
#[error("URI contained no query string")]
|
||||
QueryStringMissing,
|
||||
|
||||
#[error("failed to deserialize query string")]
|
||||
DeserializeQueryString(#[from] serde_urlencoded::de::Error),
|
||||
DeserializeQueryString(#[source] serde_urlencoded::de::Error),
|
||||
|
||||
#[error("failed generating the response body")]
|
||||
ResponseBody(#[source] BoxError),
|
||||
|
@ -198,15 +207,6 @@ pub enum Error {
|
|||
Service(#[source] BoxError),
|
||||
}
|
||||
|
||||
impl From<BoxError> for Error {
|
||||
fn from(err: BoxError) -> Self {
|
||||
match err.downcast::<Error>() {
|
||||
Ok(err) => *err,
|
||||
Err(err) => Error::Service(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Infallible> for Error {
|
||||
fn from(err: Infallible) -> Self {
|
||||
match err {}
|
||||
|
@ -316,7 +316,7 @@ where
|
|||
self.svc
|
||||
.oneshot(req)
|
||||
.await
|
||||
.map_err(|err| Error::from(err.into()))
|
||||
.map_err(|err| Error::Service(err.into()))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -425,7 +425,7 @@ where
|
|||
fn from_request(req: &mut Request<Body>) -> Self::Future {
|
||||
let result = (|| {
|
||||
let query = req.uri().query().ok_or(Error::QueryStringMissing)?;
|
||||
let value = serde_urlencoded::from_str(query)?;
|
||||
let value = serde_urlencoded::from_str(query).map_err(Error::DeserializeQueryString)?;
|
||||
Ok(Query(value))
|
||||
})();
|
||||
|
||||
|
@ -456,7 +456,7 @@ where
|
|||
Box::pin(async move {
|
||||
let bytes = hyper::body::to_bytes(body)
|
||||
.await
|
||||
.map_err(Error::ConsumeBody)?;
|
||||
.map_err(Error::ConsumeRequestBody)?;
|
||||
let value = serde_json::from_slice(&bytes).map_err(Error::DeserializeRequestBody)?;
|
||||
Ok(Json(value))
|
||||
})
|
||||
|
@ -588,53 +588,116 @@ where
|
|||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let response: Response<B> = ready!(self.project().0.poll(cx)).map_err(Into::into)?;
|
||||
let response =
|
||||
response.map(|body| body.map_err(|err| Error::ResponseBody(err.into())).boxed());
|
||||
let response = response.map(|body| {
|
||||
let body = body.map_err(|err| Error::ResponseBody(err.into()));
|
||||
BoxBody::new(body)
|
||||
});
|
||||
Poll::Ready(Ok(response))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, T> Service<T> for App<R>
|
||||
pub struct IntoService<R> {
|
||||
app: App<R>,
|
||||
poll_ready_error: Option<Error>,
|
||||
}
|
||||
|
||||
impl<R> Clone for IntoService<R>
|
||||
where
|
||||
R: Service<T>,
|
||||
R::Error: Into<Error>,
|
||||
R: Clone,
|
||||
{
|
||||
type Response = R::Response;
|
||||
type Error = R::Error;
|
||||
type Future = R::Future;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
// TODO(david): map error to response
|
||||
self.router.poll_ready(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn call(&mut self, req: T) -> Self::Future {
|
||||
// TODO(david): map error to response
|
||||
self.router.call(req)
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
app: self.app.clone(),
|
||||
poll_ready_error: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, T> Service<T> for RouteBuilder<R>
|
||||
impl<R, B, T> Service<T> for IntoService<R>
|
||||
where
|
||||
App<R>: Service<T>,
|
||||
<App<R> as Service<T>>::Error: Into<Error>,
|
||||
R: Service<T, Response = Response<B>>,
|
||||
R::Error: Into<Error>,
|
||||
B: Default,
|
||||
{
|
||||
type Response = <App<R> as Service<T>>::Response;
|
||||
type Error = <App<R> as Service<T>>::Error;
|
||||
type Future = <App<R> as Service<T>>::Future;
|
||||
type Response = Response<B>;
|
||||
type Error = Error;
|
||||
type Future = HandleErrorFuture<R::Future, B>;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
// TODO(david): map error to response
|
||||
self.app.poll_ready(cx)
|
||||
self.app.router.poll_ready(cx).map_err(Into::into)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn call(&mut self, req: T) -> Self::Future {
|
||||
// TODO(david): map error to response
|
||||
self.app.call(req)
|
||||
if let Some(poll_ready_error) = self.poll_ready_error.take() {
|
||||
match handle_error::<B>(poll_ready_error) {
|
||||
Ok(res) => {
|
||||
return HandleErrorFuture(Kind::Response(Some(res)));
|
||||
}
|
||||
Err(err) => {
|
||||
return HandleErrorFuture(Kind::Error(Some(err)));
|
||||
}
|
||||
}
|
||||
}
|
||||
HandleErrorFuture(Kind::Future(self.app.router.call(req)))
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
pub struct HandleErrorFuture<F, B>(#[pin] Kind<F, B>);
|
||||
|
||||
#[pin_project(project = KindProj)]
|
||||
enum Kind<F, B> {
|
||||
Response(Option<Response<B>>),
|
||||
Error(Option<Error>),
|
||||
Future(#[pin] F),
|
||||
}
|
||||
|
||||
impl<F, B, E> Future for HandleErrorFuture<F, B>
|
||||
where
|
||||
F: Future<Output = Result<Response<B>, E>>,
|
||||
E: Into<Error>,
|
||||
B: Default,
|
||||
{
|
||||
type Output = Result<Response<B>, Error>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.project().0.project() {
|
||||
KindProj::Response(res) => Poll::Ready(Ok(res.take().unwrap())),
|
||||
KindProj::Error(err) => Poll::Ready(Err(err.take().unwrap())),
|
||||
KindProj::Future(fut) => match ready!(fut.poll(cx)) {
|
||||
Ok(res) => Poll::Ready(Ok(res)),
|
||||
Err(err) => Poll::Ready(handle_error(err.into())),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_error<B>(error: Error) -> Result<Response<B>, Error>
|
||||
where
|
||||
B: Default,
|
||||
{
|
||||
fn make_response<B>(status: StatusCode) -> Result<Response<B>, Error>
|
||||
where
|
||||
B: Default,
|
||||
{
|
||||
let mut res = Response::new(B::default());
|
||||
*res.status_mut() = status;
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
match error {
|
||||
Error::DeserializeRequestBody(_)
|
||||
| Error::QueryStringMissing
|
||||
| Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST),
|
||||
|
||||
Error::Service(err) => match err.downcast::<Error>() {
|
||||
Ok(err) => Err(*err),
|
||||
Err(err) => Err(Error::Service(err)),
|
||||
},
|
||||
|
||||
err @ Error::ConsumeRequestBody(_) => Err(err),
|
||||
err @ Error::ResponseBody(_) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -648,7 +711,10 @@ mod tests {
|
|||
use tower::{
|
||||
layer::util::Identity, make::Shared, service_fn, timeout::TimeoutLayer, ServiceBuilder,
|
||||
};
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tower_http::{
|
||||
compression::CompressionLayer,
|
||||
trace::{Trace, TraceLayer},
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn basic() {
|
||||
|
@ -667,9 +733,15 @@ mod tests {
|
|||
Ok(Response::new(Body::from("Hello, World!")))
|
||||
}
|
||||
|
||||
let mut app = app()
|
||||
async fn large_static_file(_: Request<Body>) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::empty()))
|
||||
}
|
||||
|
||||
let app = app()
|
||||
// routes with functions
|
||||
.at("/")
|
||||
.get(root.layer(TimeoutLayer::new(Duration::from_secs(30))))
|
||||
.get(root)
|
||||
// routes with closures
|
||||
.at("/users")
|
||||
.get(|_: Request<Body>, pagination: Query<Pagination>| async {
|
||||
let pagination = pagination.into_inner();
|
||||
|
@ -684,8 +756,23 @@ mod tests {
|
|||
|
||||
Ok::<_, Error>(Response::new(Body::from("users#create")))
|
||||
})
|
||||
// routes with a service
|
||||
.at("/service")
|
||||
.get_service(service_fn(root));
|
||||
.get_service(service_fn(root))
|
||||
// routes with layers applied
|
||||
.at("/large-static-file")
|
||||
.get(
|
||||
large_static_file.layer(
|
||||
ServiceBuilder::new()
|
||||
.layer(TimeoutLayer::new(Duration::from_secs(30)))
|
||||
.layer(CompressionLayer::new())
|
||||
.into_inner(),
|
||||
),
|
||||
)
|
||||
.into_service();
|
||||
|
||||
// can add more middleware
|
||||
let mut app = Trace::new_for_http(app);
|
||||
|
||||
let res = app
|
||||
.ready()
|
||||
|
@ -719,6 +806,22 @@ mod tests {
|
|||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(body_to_string(res).await, "users#index");
|
||||
|
||||
let res = app
|
||||
.ready()
|
||||
.await
|
||||
.unwrap()
|
||||
.call(
|
||||
Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri("/users")
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
assert_eq!(body_to_string(res).await, "");
|
||||
|
||||
let res = app
|
||||
.ready()
|
||||
.await
|
||||
|
@ -748,12 +851,16 @@ mod tests {
|
|||
#[allow(dead_code)]
|
||||
// this should just compile
|
||||
async fn compatible_with_hyper_and_tower_http() {
|
||||
let app = app().at("/").get(|_: Request<Body>| async {
|
||||
Ok::<_, Error>(Response::new(Body::from("Hello, World!")))
|
||||
});
|
||||
let app = app()
|
||||
.at("/")
|
||||
.get(|_: Request<Body>| async {
|
||||
Ok::<_, Error>(Response::new(Body::from("Hello, World!")))
|
||||
})
|
||||
.into_service();
|
||||
|
||||
let app = ServiceBuilder::new()
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(CompressionLayer::new())
|
||||
.service(app);
|
||||
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
|
|
Loading…
Add table
Reference in a new issue