mirror of
https://github.com/tokio-rs/axum.git
synced 2025-03-13 19:27:53 +01:00
More error handling
This commit is contained in:
parent
18f613ff98
commit
0e38037c74
10 changed files with 222 additions and 228 deletions
|
@ -2,7 +2,7 @@
|
|||
|
||||
This is *not* https://github.com/carllerche/tower-web even though the name is
|
||||
the same. Its just a prototype of a minimal HTTP framework I've been toying
|
||||
with.
|
||||
with. Will probably change the name to something else.
|
||||
|
||||
# What is this?
|
||||
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
#![allow(warnings)]
|
||||
|
||||
use bytes::Bytes;
|
||||
use http::{Request, Response, StatusCode};
|
||||
use http::{Request, StatusCode};
|
||||
use hyper::Server;
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
net::SocketAddr,
|
||||
|
@ -14,12 +11,7 @@ use tower::{make::Shared, ServiceBuilder};
|
|||
use tower_http::{
|
||||
add_extension::AddExtensionLayer, compression::CompressionLayer, trace::TraceLayer,
|
||||
};
|
||||
use tower_web::{
|
||||
body::Body,
|
||||
extract,
|
||||
response::{self, IntoResponse},
|
||||
Error,
|
||||
};
|
||||
use tower_web::{body::Body, extract};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
|
@ -59,7 +51,11 @@ async fn get(
|
|||
_req: Request<Body>,
|
||||
params: extract::UrlParams<(String,)>,
|
||||
state: extract::Extension<SharedState>,
|
||||
) -> Result<Bytes, NotFound> {
|
||||
// Anything that implements `IntoResponse` can be used a response
|
||||
//
|
||||
// Handlers cannot return errors. Everything will be converted
|
||||
// into a response. `BoxError` becomes `500 Internal server error`
|
||||
) -> Result<Bytes, StatusCode> {
|
||||
let state = state.into_inner();
|
||||
let db = &state.lock().unwrap().db;
|
||||
|
||||
|
@ -68,7 +64,7 @@ async fn get(
|
|||
if let Some(value) = db.get(&key) {
|
||||
Ok(value.clone())
|
||||
} else {
|
||||
Err(NotFound)
|
||||
Err(StatusCode::NOT_FOUND)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -77,6 +73,8 @@ async fn set(
|
|||
params: extract::UrlParams<(String,)>,
|
||||
value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb
|
||||
state: extract::Extension<SharedState>,
|
||||
// `()` also implements `IntoResponse` so we can use that to return
|
||||
// an empty response
|
||||
) {
|
||||
let state = state.into_inner();
|
||||
let db = &mut state.lock().unwrap().db;
|
||||
|
@ -84,16 +82,5 @@ async fn set(
|
|||
let key = params.into_inner();
|
||||
let value = value.into_inner();
|
||||
|
||||
db.insert(key.to_string(), value);
|
||||
}
|
||||
|
||||
struct NotFound;
|
||||
|
||||
impl IntoResponse<Body> for NotFound {
|
||||
fn into_response(self) -> Response<Body> {
|
||||
Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Body::empty())
|
||||
.unwrap()
|
||||
}
|
||||
db.insert(key, value);
|
||||
}
|
||||
|
|
29
src/body.rs
29
src/body.rs
|
@ -1,4 +1,5 @@
|
|||
use bytes::Buf;
|
||||
use futures_util::ready;
|
||||
use http_body::{Body as _, Empty};
|
||||
use std::{
|
||||
fmt,
|
||||
|
@ -8,6 +9,8 @@ use std::{
|
|||
|
||||
pub use hyper::body::Body;
|
||||
|
||||
use crate::BoxStdError;
|
||||
|
||||
/// A boxed [`Body`] trait object.
|
||||
pub struct BoxBody<D, E> {
|
||||
inner: Pin<Box<dyn http_body::Body<Data = D, Error = E> + Send + Sync + 'static>>,
|
||||
|
@ -42,25 +45,34 @@ impl<D, E> fmt::Debug for BoxBody<D, E> {
|
|||
}
|
||||
}
|
||||
|
||||
// when we've gotten rid of `BoxStdError` then we can remove this
|
||||
impl<D, E> http_body::Body for BoxBody<D, E>
|
||||
where
|
||||
D: Buf,
|
||||
E: Into<tower::BoxError>,
|
||||
{
|
||||
type Data = D;
|
||||
type Error = E;
|
||||
type Error = BoxStdError;
|
||||
|
||||
fn poll_data(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
|
||||
self.inner.as_mut().poll_data(cx)
|
||||
match ready!(self.inner.as_mut().poll_data(cx)) {
|
||||
Some(Ok(chunk)) => Some(Ok(chunk)).into(),
|
||||
Some(Err(err)) => Some(Err(BoxStdError(err.into()))).into(),
|
||||
None => None.into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_trailers(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
|
||||
self.inner.as_mut().poll_trailers(cx)
|
||||
match ready!(self.inner.as_mut().poll_trailers(cx)) {
|
||||
Ok(trailers) => Ok(trailers).into(),
|
||||
Err(err) => Err(BoxStdError(err.into())).into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_end_stream(&self) -> bool {
|
||||
|
@ -71,3 +83,14 @@ where
|
|||
self.inner.size_hint()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for BoxBody<bytes::Bytes, tower::BoxError> {
|
||||
fn from(s: String) -> Self {
|
||||
let body = hyper::Body::from(s);
|
||||
let body = body.map_err(Into::<tower::BoxError>::into);
|
||||
|
||||
BoxBody {
|
||||
inner: Box::pin(body),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
104
src/error.rs
104
src/error.rs
|
@ -1,104 +0,0 @@
|
|||
use std::convert::Infallible;
|
||||
|
||||
use http::{Response, StatusCode};
|
||||
use tower::BoxError;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[non_exhaustive]
|
||||
pub enum Error {
|
||||
#[error("failed to deserialize the request body")]
|
||||
DeserializeRequestBody(#[source] serde_json::Error),
|
||||
|
||||
#[error("failed to serialize the response body")]
|
||||
SerializeResponseBody(#[source] serde_json::Error),
|
||||
|
||||
#[error("failed to consume the body")]
|
||||
ConsumeRequestBody(#[source] hyper::Error),
|
||||
|
||||
#[error("URI contained no query string")]
|
||||
QueryStringMissing,
|
||||
|
||||
#[error("failed to deserialize query string")]
|
||||
DeserializeQueryString(#[source] serde_urlencoded::de::Error),
|
||||
|
||||
#[error("failed generating the response body")]
|
||||
ResponseBody(#[source] BoxError),
|
||||
|
||||
#[error("some dynamic error happened")]
|
||||
Dynamic(#[source] BoxError),
|
||||
|
||||
#[error("request extension of type `{type_name}` was not set")]
|
||||
MissingExtension { type_name: &'static str },
|
||||
|
||||
#[error("`Content-Length` header is missing but was required")]
|
||||
LengthRequired,
|
||||
|
||||
#[error("response body was too large")]
|
||||
PayloadTooLarge,
|
||||
|
||||
#[error("response failed with status {0}")]
|
||||
Status(StatusCode),
|
||||
|
||||
#[error("invalid URL param. Expected something of type `{type_name}`")]
|
||||
InvalidUrlParam { type_name: &'static str },
|
||||
|
||||
#[error("unknown URL param `{0}`")]
|
||||
UnknownUrlParam(String),
|
||||
|
||||
#[error("response body didn't contain valid UTF-8")]
|
||||
InvalidUtf8,
|
||||
}
|
||||
|
||||
impl From<BoxError> for Error {
|
||||
fn from(err: BoxError) -> Self {
|
||||
match err.downcast::<Error>() {
|
||||
Ok(err) => *err,
|
||||
Err(err) => Error::Dynamic(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Infallible> for Error {
|
||||
fn from(err: Infallible) -> Self {
|
||||
match err {}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn handle_error<B>(error: Error) -> Result<Response<B>, Error>
|
||||
where
|
||||
B: Default,
|
||||
{
|
||||
fn make_response<B>(status: StatusCode) -> Result<Response<B>, Error>
|
||||
where
|
||||
B: Default,
|
||||
{
|
||||
let mut res = Response::new(B::default());
|
||||
*res.status_mut() = status;
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
match error {
|
||||
Error::DeserializeRequestBody(_)
|
||||
| Error::QueryStringMissing
|
||||
| Error::DeserializeQueryString(_)
|
||||
| Error::InvalidUrlParam { .. }
|
||||
| Error::InvalidUtf8 => make_response(StatusCode::BAD_REQUEST),
|
||||
|
||||
Error::Status(status) => make_response(status),
|
||||
|
||||
Error::LengthRequired => make_response(StatusCode::LENGTH_REQUIRED),
|
||||
Error::PayloadTooLarge => make_response(StatusCode::PAYLOAD_TOO_LARGE),
|
||||
|
||||
Error::MissingExtension { .. }
|
||||
| Error::SerializeResponseBody(_)
|
||||
| Error::UnknownUrlParam(_) => make_response(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
|
||||
Error::Dynamic(err) => match err.downcast::<Error>() {
|
||||
Ok(err) => Err(*err),
|
||||
Err(err) => Err(Error::Dynamic(err)),
|
||||
},
|
||||
|
||||
err @ Error::ConsumeRequestBody(_) => Err(err),
|
||||
err @ Error::ResponseBody(_) => Err(err),
|
||||
}
|
||||
}
|
|
@ -1,7 +1,6 @@
|
|||
use crate::{
|
||||
body::Body,
|
||||
response::{BoxIntoResponse, IntoResponse},
|
||||
Error,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
|
@ -315,21 +314,15 @@ define_rejection! {
|
|||
pub struct UrlParamsMap(HashMap<String, String>);
|
||||
|
||||
impl UrlParamsMap {
|
||||
pub fn get(&self, key: &str) -> Result<&str, Error> {
|
||||
if let Some(value) = self.0.get(key) {
|
||||
Ok(value)
|
||||
} else {
|
||||
Err(Error::UnknownUrlParam(key.to_string()))
|
||||
}
|
||||
pub fn get(&self, key: &str) -> Option<&str> {
|
||||
self.0.get(key).map(|s| &**s)
|
||||
}
|
||||
|
||||
pub fn get_typed<T>(&self, key: &str) -> Result<T, Error>
|
||||
pub fn get_typed<T>(&self, key: &str) -> Option<T>
|
||||
where
|
||||
T: FromStr,
|
||||
{
|
||||
self.get(key)?.parse().map_err(|_| Error::InvalidUrlParam {
|
||||
type_name: std::any::type_name::<T>(),
|
||||
})
|
||||
self.get(key)?.parse().ok()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{body::Body, error::Error, extract::FromRequest, response::IntoResponse};
|
||||
use crate::{body::Body, extract::FromRequest, response::IntoResponse};
|
||||
use async_trait::async_trait;
|
||||
use futures_util::future;
|
||||
use http::{Request, Response};
|
||||
|
@ -8,7 +8,7 @@ use std::{
|
|||
marker::PhantomData,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tower::{BoxError, Layer, Service, ServiceExt};
|
||||
use tower::{Layer, Service, ServiceExt};
|
||||
|
||||
mod sealed {
|
||||
pub trait HiddentTrait {}
|
||||
|
@ -30,6 +30,8 @@ pub trait Handler<B, In>: Sized {
|
|||
fn layer<L>(self, layer: L) -> Layered<L::Service, In>
|
||||
where
|
||||
L: Layer<HandlerSvc<Self, B, In>>,
|
||||
<L as Layer<HandlerSvc<Self, B, In>>>::Service: Service<Request<Body>>,
|
||||
<<L as Layer<HandlerSvc<Self, B, In>>>::Service as Service<Request<Body>>>::Error: IntoResponse<B>,
|
||||
{
|
||||
Layered::new(layer.layer(HandlerSvc::new(self)))
|
||||
}
|
||||
|
|
82
src/lib.rs
82
src/lib.rs
|
@ -7,6 +7,7 @@ use futures_util::ready;
|
|||
use http::Response;
|
||||
use pin_project::pin_project;
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
|
@ -19,13 +20,9 @@ pub mod handler;
|
|||
pub mod response;
|
||||
pub mod routing;
|
||||
|
||||
mod error;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
pub use self::error::Error;
|
||||
|
||||
pub fn app() -> App<AlwaysNotFound> {
|
||||
App {
|
||||
service_tree: AlwaysNotFound(()),
|
||||
|
@ -52,7 +49,6 @@ impl<R> App<R> {
|
|||
|
||||
pub struct IntoService<R> {
|
||||
app: App<R>,
|
||||
poll_ready_error: Option<Error>,
|
||||
}
|
||||
|
||||
impl<R> Clone for IntoService<R>
|
||||
|
@ -62,71 +58,67 @@ where
|
|||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
app: self.app.clone(),
|
||||
poll_ready_error: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, B, T> Service<T> for IntoService<R>
|
||||
where
|
||||
R: Service<T, Response = Response<B>>,
|
||||
R::Error: Into<Error>,
|
||||
R: Service<T, Response = Response<B>, Error = Infallible>,
|
||||
B: Default,
|
||||
{
|
||||
type Response = Response<B>;
|
||||
type Error = Error;
|
||||
type Future = HandleErrorFuture<R::Future, B>;
|
||||
type Error = Infallible;
|
||||
type Future = HandleErrorFuture<R::Future>;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
if let Err(err) = ready!(self.app.service_tree.poll_ready(cx)).map_err(Into::into) {
|
||||
self.poll_ready_error = Some(err);
|
||||
match ready!(self.app.service_tree.poll_ready(cx)) {
|
||||
Ok(_) => Poll::Ready(Ok(())),
|
||||
Err(err) => match err {},
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: T) -> Self::Future {
|
||||
if let Some(poll_ready_error) = self.poll_ready_error.take() {
|
||||
match error::handle_error::<B>(poll_ready_error) {
|
||||
Ok(res) => {
|
||||
return HandleErrorFuture(Kind::Response(Some(res)));
|
||||
}
|
||||
Err(err) => {
|
||||
return HandleErrorFuture(Kind::Error(Some(err)));
|
||||
}
|
||||
}
|
||||
}
|
||||
HandleErrorFuture(Kind::Future(self.app.service_tree.call(req)))
|
||||
HandleErrorFuture(self.app.service_tree.call(req))
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
pub struct HandleErrorFuture<F, B>(#[pin] Kind<F, B>);
|
||||
pub struct HandleErrorFuture<F>(#[pin] F);
|
||||
|
||||
#[pin_project(project = KindProj)]
|
||||
enum Kind<F, B> {
|
||||
Response(Option<Response<B>>),
|
||||
Error(Option<Error>),
|
||||
Future(#[pin] F),
|
||||
}
|
||||
|
||||
impl<F, B, E> Future for HandleErrorFuture<F, B>
|
||||
impl<F, B> Future for HandleErrorFuture<F>
|
||||
where
|
||||
F: Future<Output = Result<Response<B>, E>>,
|
||||
E: Into<Error>,
|
||||
F: Future<Output = Result<Response<B>, Infallible>>,
|
||||
B: Default,
|
||||
{
|
||||
type Output = Result<Response<B>, Error>;
|
||||
type Output = Result<Response<B>, Infallible>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.project().0.project() {
|
||||
KindProj::Response(res) => Poll::Ready(Ok(res.take().unwrap())),
|
||||
KindProj::Error(err) => Poll::Ready(Err(err.take().unwrap())),
|
||||
KindProj::Future(fut) => match ready!(fut.poll(cx)) {
|
||||
Ok(res) => Poll::Ready(Ok(res)),
|
||||
Err(err) => Poll::Ready(error::handle_error(err.into())),
|
||||
},
|
||||
self.project().0.poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait ResultExt<T> {
|
||||
fn unwrap_infallible(self) -> T;
|
||||
}
|
||||
|
||||
impl<T> ResultExt<T> for Result<T, Infallible> {
|
||||
fn unwrap_infallible(self) -> T {
|
||||
match self {
|
||||
Ok(value) => value,
|
||||
Err(err) => match err {},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// work around for `BoxError` not implementing `std::error::Error`
|
||||
//
|
||||
// This is currently required since tower-http's Compression middleware's body type's
|
||||
// error only implements error when the inner error type does:
|
||||
// https://github.com/tower-rs/tower-http/blob/master/tower-http/src/lib.rs#L310
|
||||
//
|
||||
// Fixing that is a breaking change to tower-http so we should wait a bit, but should
|
||||
// totally fix it at some point.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("{0}")]
|
||||
pub struct BoxStdError(#[source] tower::BoxError);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::Body;
|
||||
use bytes::Bytes;
|
||||
use http::{header, HeaderValue, Response, StatusCode};
|
||||
use http::{HeaderMap, HeaderValue, Response, StatusCode, header};
|
||||
use serde::Serialize;
|
||||
use std::convert::Infallible;
|
||||
use tower::{util::Either, BoxError};
|
||||
|
@ -176,9 +176,45 @@ impl<B> IntoResponse<B> for BoxIntoResponse<B> {
|
|||
|
||||
impl IntoResponse<Body> for BoxError {
|
||||
fn into_response(self) -> Response<Body> {
|
||||
// TODO(david): test for know error types like std::io::Error
|
||||
// or common errors types from tower and map those more appropriately
|
||||
|
||||
Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::from(self.to_string()))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> IntoResponse<B> for StatusCode
|
||||
where
|
||||
B: Default,
|
||||
{
|
||||
fn into_response(self) -> Response<B> {
|
||||
Response::builder().status(self).body(B::default()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> IntoResponse<Body> for (StatusCode, T)
|
||||
where
|
||||
T: Into<Body>,
|
||||
{
|
||||
fn into_response(self) -> Response<Body> {
|
||||
Response::builder()
|
||||
.status(self.0)
|
||||
.body(self.1.into())
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> IntoResponse<Body> for (StatusCode, HeaderMap, T)
|
||||
where
|
||||
T: Into<Body>,
|
||||
{
|
||||
fn into_response(self) -> Response<Body> {
|
||||
let mut res = Response::new(self.2.into());
|
||||
*res.status_mut() = self.0;
|
||||
*res.headers_mut() = self.1;
|
||||
res
|
||||
}
|
||||
}
|
||||
|
|
137
src/routing.rs
137
src/routing.rs
|
@ -1,8 +1,7 @@
|
|||
use crate::{
|
||||
body::{Body, BoxBody},
|
||||
error::Error,
|
||||
handler::{Handler, HandlerSvc},
|
||||
App, IntoService,
|
||||
App, IntoService, ResultExt,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use futures_util::{future, ready};
|
||||
|
@ -164,17 +163,18 @@ impl<R> RouteBuilder<R> {
|
|||
}
|
||||
|
||||
pub fn into_service(self) -> IntoService<R> {
|
||||
IntoService {
|
||||
app: self.app,
|
||||
poll_ready_error: None,
|
||||
}
|
||||
IntoService { app: self.app }
|
||||
}
|
||||
|
||||
// TODO(david): Add `layer` method here that applies a `tower::Layer` inside the service tree
|
||||
// that way we get to map errors
|
||||
|
||||
pub fn boxed<B>(self) -> RouteBuilder<BoxServiceTree<B>>
|
||||
where
|
||||
R: Service<Request<Body>, Response = Response<B>, Error = Error> + Send + 'static,
|
||||
R: Service<Request<Body>, Response = Response<B>, Error = Infallible> + Send + 'static,
|
||||
R::Future: Send,
|
||||
B: Default + 'static,
|
||||
// TODO(david): do we still need default here
|
||||
B: Default + From<String> + 'static,
|
||||
{
|
||||
let svc = ServiceBuilder::new()
|
||||
.layer(BufferLayer::new(1024))
|
||||
|
@ -182,7 +182,10 @@ impl<R> RouteBuilder<R> {
|
|||
.service(self.app.service_tree);
|
||||
|
||||
let app = App {
|
||||
service_tree: BoxServiceTree { inner: svc },
|
||||
service_tree: BoxServiceTree {
|
||||
inner: svc,
|
||||
poll_ready_error: None,
|
||||
},
|
||||
};
|
||||
|
||||
RouteBuilder {
|
||||
|
@ -270,29 +273,27 @@ impl RouteSpec {
|
|||
|
||||
impl<H, F, HB, FB> Service<Request<Body>> for Or<H, F>
|
||||
where
|
||||
H: Service<Request<Body>, Response = Response<HB>>,
|
||||
H::Error: Into<Error>,
|
||||
H: Service<Request<Body>, Response = Response<HB>, Error = Infallible>,
|
||||
HB: http_body::Body + Send + Sync + 'static,
|
||||
HB::Error: Into<BoxError>,
|
||||
|
||||
F: Service<Request<Body>, Response = Response<FB>>,
|
||||
F::Error: Into<Error>,
|
||||
F: Service<Request<Body>, Response = Response<FB>, Error = Infallible>,
|
||||
FB: http_body::Body<Data = HB::Data> + Send + Sync + 'static,
|
||||
FB::Error: Into<BoxError>,
|
||||
{
|
||||
type Response = Response<BoxBody<HB::Data, Error>>;
|
||||
type Error = Error;
|
||||
type Response = Response<BoxBody<HB::Data, BoxError>>;
|
||||
type Error = Infallible;
|
||||
type Future = future::Either<BoxResponseBody<H::Future>, BoxResponseBody<F::Future>>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
loop {
|
||||
if !self.handler_ready {
|
||||
ready!(self.service.poll_ready(cx)).map_err(Into::into)?;
|
||||
ready!(self.service.poll_ready(cx)).unwrap_infallible();
|
||||
self.handler_ready = true;
|
||||
}
|
||||
|
||||
if !self.fallback_ready {
|
||||
ready!(self.fallback.poll_ready(cx)).map_err(Into::into)?;
|
||||
ready!(self.fallback.poll_ready(cx)).unwrap_infallible();
|
||||
self.fallback_ready = true;
|
||||
}
|
||||
|
||||
|
@ -333,20 +334,18 @@ pub(crate) struct UrlParams(pub(crate) Vec<(String, String)>);
|
|||
#[pin_project]
|
||||
pub struct BoxResponseBody<F>(#[pin] F);
|
||||
|
||||
impl<F, B, E> Future for BoxResponseBody<F>
|
||||
impl<F, B> Future for BoxResponseBody<F>
|
||||
where
|
||||
F: Future<Output = Result<Response<B>, E>>,
|
||||
E: Into<Error>,
|
||||
F: Future<Output = Result<Response<B>, Infallible>>,
|
||||
B: http_body::Body + Send + Sync + 'static,
|
||||
B::Error: Into<BoxError>,
|
||||
{
|
||||
type Output = Result<Response<BoxBody<B::Data, Error>>, Error>;
|
||||
type Output = Result<Response<BoxBody<B::Data, BoxError>>, Infallible>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let response: Response<B> = ready!(self.project().0.poll(cx)).map_err(Into::into)?;
|
||||
let response: Response<B> = ready!(self.project().0.poll(cx)).unwrap_infallible();
|
||||
let response = response.map(|body| {
|
||||
// TODO(david): attempt to downcast this into `Error`
|
||||
let body = body.map_err(|err| Error::ResponseBody(err.into()));
|
||||
let body = body.map_err(Into::into);
|
||||
BoxBody::new(body)
|
||||
});
|
||||
Poll::Ready(Ok(response))
|
||||
|
@ -354,13 +353,15 @@ where
|
|||
}
|
||||
|
||||
pub struct BoxServiceTree<B> {
|
||||
inner: Buffer<BoxService<Request<Body>, Response<B>, Error>, Request<Body>>,
|
||||
inner: Buffer<BoxService<Request<Body>, Response<B>, Infallible>, Request<Body>>,
|
||||
poll_ready_error: Option<BoxError>,
|
||||
}
|
||||
|
||||
impl<B> Clone for BoxServiceTree<B> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
inner: self.inner.clone(),
|
||||
poll_ready_error: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -373,21 +374,36 @@ impl<B> fmt::Debug for BoxServiceTree<B> {
|
|||
|
||||
impl<B> Service<Request<Body>> for BoxServiceTree<B>
|
||||
where
|
||||
B: 'static,
|
||||
B: From<String> + 'static,
|
||||
{
|
||||
type Response = Response<B>;
|
||||
type Error = Error;
|
||||
type Error = Infallible;
|
||||
type Future = BoxServiceTreeResponseFuture<B>;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx).map_err(Error::from)
|
||||
// TODO(david): downcast this into one of the cases in `tower::buffer::error`
|
||||
// and convert the error into a response. `ServiceError` should never be able to happen
|
||||
// since all inner services use `Infallible` as the error type.
|
||||
match ready!(self.inner.poll_ready(cx)) {
|
||||
Ok(_) => Poll::Ready(Ok(())),
|
||||
Err(err) => {
|
||||
self.poll_ready_error = Some(err);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
||||
if let Some(err) = self.poll_ready_error.take() {
|
||||
return BoxServiceTreeResponseFuture {
|
||||
kind: Kind::Response(Some(handle_buffer_error(err))),
|
||||
};
|
||||
}
|
||||
|
||||
BoxServiceTreeResponseFuture {
|
||||
inner: self.inner.call(req),
|
||||
kind: Kind::Future(self.inner.call(req)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -395,24 +411,71 @@ where
|
|||
#[pin_project]
|
||||
pub struct BoxServiceTreeResponseFuture<B> {
|
||||
#[pin]
|
||||
inner: InnerFuture<B>,
|
||||
kind: Kind<B>,
|
||||
}
|
||||
|
||||
#[pin_project(project = KindProj)]
|
||||
enum Kind<B> {
|
||||
Response(Option<Response<B>>),
|
||||
Future(#[pin] InnerFuture<B>),
|
||||
}
|
||||
|
||||
type InnerFuture<B> = tower::buffer::future::ResponseFuture<
|
||||
Pin<Box<dyn Future<Output = Result<Response<B>, Error>> + Send + 'static>>,
|
||||
Pin<Box<dyn Future<Output = Result<Response<B>, Infallible>> + Send + 'static>>,
|
||||
>;
|
||||
|
||||
impl<B> Future for BoxServiceTreeResponseFuture<B> {
|
||||
type Output = Result<Response<B>, Error>;
|
||||
impl<B> Future for BoxServiceTreeResponseFuture<B>
|
||||
where
|
||||
B: From<String>,
|
||||
{
|
||||
type Output = Result<Response<B>, Infallible>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.project()
|
||||
.inner
|
||||
.poll(cx)
|
||||
.map_err(Error::from)
|
||||
match self.project().kind.project() {
|
||||
KindProj::Response(res) => Poll::Ready(Ok(res.take().unwrap())),
|
||||
KindProj::Future(future) => match ready!(future.poll(cx)) {
|
||||
Ok(res) => Poll::Ready(Ok(res)),
|
||||
Err(err) => Poll::Ready(Ok(handle_buffer_error(err))),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_buffer_error<B>(error: BoxError) -> Response<B>
|
||||
where
|
||||
B: From<String>,
|
||||
{
|
||||
use tower::buffer::error::{Closed, ServiceError};
|
||||
|
||||
let error = match error.downcast::<Closed>() {
|
||||
Ok(closed) => {
|
||||
return Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(B::from(closed.to_string()))
|
||||
.unwrap();
|
||||
}
|
||||
Err(e) => e,
|
||||
};
|
||||
|
||||
let error = match error.downcast::<ServiceError>() {
|
||||
Ok(service_error) => {
|
||||
return Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(B::from(format!("Service error: {}. This is a bug in tower-web. All inner services should be infallible. Please file an issue", service_error)))
|
||||
.unwrap();
|
||||
}
|
||||
Err(e) => e,
|
||||
};
|
||||
|
||||
Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(B::from(format!(
|
||||
"Uncountered an unknown error: {}. This should never happen. Please file an issue",
|
||||
error
|
||||
)))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[allow(unused_imports)]
|
||||
|
|
|
@ -273,6 +273,8 @@ async fn boxing() {
|
|||
assert_eq!(res.text().await.unwrap(), "hi from POST");
|
||||
}
|
||||
|
||||
// TODO(david): tests for adding middleware to single services
|
||||
|
||||
/// Run a `tower::Service` in the background and get a URI for it.
|
||||
pub async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
|
||||
where
|
||||
|
|
Loading…
Add table
Reference in a new issue