Move methods from ServiceExt to RoutingDsl (#160)

Previously, on `main`, this wouldn't compile:

```rust
let app = route("/", get(handler))
    .layer(
        ServiceBuilder::new()
            .timeout(Duration::from_secs(10))
            .into_inner(),
    )
    .handle_error(...)
    .route(...); // <-- doesn't work
```

That is because `handle_error` would be
`axum::service::ServiceExt::handle_error` which returns `HandleError<_,
_, _, HandleErrorFromService>` which does _not_ implement `RoutingDsl`.
So you couldn't call `route`. This was caused by
https://github.com/tokio-rs/axum/pull/120.

Basically `handle_error` when called on a `RoutingDsl`, the resulting
service should also implement `RoutingDsl`, but if called on another
random service it should _not_ implement `RoutingDsl`.

I don't think thats possible by having `handle_error` on `ServiceExt`
which is implemented for any service, since all axum routers are also
services by design.

This resolves the issue by removing `ServiceExt` and moving its methods
to `RoutingDsl`. Then we have more tight control over what has a
`handle_error` method.

`service::OnMethod` now also has a `handle_error` so you can still
handle errors from random services, by doing
`service::any(svc).handle_error(...)`.
This commit is contained in:
David Pedersen 2021-08-08 14:30:51 +02:00 committed by GitHub
parent 9b3f3c9bdf
commit 8013165908
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 346 additions and 297 deletions

View file

@ -9,8 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Make `FromRequest` default to being generic over `body::Body` ([#146](https://github.com/tokio-rs/axum/pull/146)) - Make `FromRequest` default to being generic over `body::Body` ([#146](https://github.com/tokio-rs/axum/pull/146))
- Implement `std::error::Error` for all rejections ([#153](https://github.com/tokio-rs/axum/pull/153)) - Implement `std::error::Error` for all rejections ([#153](https://github.com/tokio-rs/axum/pull/153))
- Fix `Uri` extractor not being the full URI if using `nest` ([#156](https://github.com/tokio-rs/axum/pull/156))
- Add `RoutingDsl::or` for combining routes ([#108](https://github.com/tokio-rs/axum/pull/108)) - Add `RoutingDsl::or` for combining routes ([#108](https://github.com/tokio-rs/axum/pull/108))
- Add `handle_error` to `service::OnMethod` ([#160](https://github.com/tokio-rs/axum/pull/160))
## Breaking changes ## Breaking changes
@ -24,9 +24,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Ensure a `HandleError` service created from `ServiceExt::handle_error` - Ensure a `HandleError` service created from `ServiceExt::handle_error`
_does not_ implement `RoutingDsl` as that could lead to confusing routing _does not_ implement `RoutingDsl` as that could lead to confusing routing
behavior ([#120](https://github.com/tokio-rs/axum/pull/120)) behavior ([#120](https://github.com/tokio-rs/axum/pull/120))
- Fix `Uri` extractor not being the full URI if using `nest` ([#156](https://github.com/tokio-rs/axum/pull/156))
- Implement `routing::MethodFilter` via [`bitflags`](https://crates.io/crates/bitflags) - Implement `routing::MethodFilter` via [`bitflags`](https://crates.io/crates/bitflags)
- Removed `extract::UrlParams` and `extract::UrlParamsMap`. Use `extract::Path` instead - Removed `extract::UrlParams` and `extract::UrlParamsMap`. Use `extract::Path` instead
- `EmptyRouter` now requires the response body to implement `Send + Sync + 'static'` ([#108](https://github.com/tokio-rs/axum/pull/108)) - `EmptyRouter` now requires the response body to implement `Send + Sync + 'static'` ([#108](https://github.com/tokio-rs/axum/pull/108))
- `ServiceExt` has been removed and its methods have been moved to `RoutingDsl` ([#160](https://github.com/tokio-rs/axum/pull/160))
- These future types have been moved - These future types have been moved
- `extract::extractor_middleware::ExtractorMiddlewareResponseFuture` moved - `extract::extractor_middleware::ExtractorMiddlewareResponseFuture` moved
to `extract::extractor_middleware::future::ResponseFuture` ([#133](https://github.com/tokio-rs/axum/pull/133)) to `extract::extractor_middleware::future::ResponseFuture` ([#133](https://github.com/tokio-rs/axum/pull/133))

View file

@ -12,7 +12,6 @@ use axum::{
prelude::*, prelude::*,
response::IntoResponse, response::IntoResponse,
routing::BoxRoute, routing::BoxRoute,
service::ServiceExt,
}; };
use bytes::Bytes; use bytes::Bytes;
use http::StatusCode; use http::StatusCode;

View file

@ -4,7 +4,7 @@
//! cargo run --example sse --features=headers //! cargo run --example sse --features=headers
//! ``` //! ```
use axum::{extract::TypedHeader, prelude::*, routing::nest, service::ServiceExt, sse::Event}; use axum::{extract::TypedHeader, prelude::*, routing::nest, sse::Event};
use futures::stream::{self, Stream}; use futures::stream::{self, Stream};
use http::StatusCode; use http::StatusCode;
use std::{convert::Infallible, net::SocketAddr, time::Duration}; use std::{convert::Infallible, net::SocketAddr, time::Duration};
@ -19,20 +19,17 @@ async fn main() {
} }
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
// build our application with a route let static_files_service =
let app = nest( axum::service::get(ServeDir::new("examples/sse").append_index_html_on_directories(true))
"/",
axum::service::get(
ServeDir::new("examples/sse")
.append_index_html_on_directories(true)
.handle_error(|error: std::io::Error| { .handle_error(|error: std::io::Error| {
Ok::<_, std::convert::Infallible>(( Ok::<_, std::convert::Infallible>((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error), format!("Unhandled internal error: {}", error),
)) ))
}), });
),
) // build our application with a route
let app = nest("/", static_files_service)
.route("/sse", axum::sse::sse(make_stream)) .route("/sse", axum::sse::sse(make_stream))
.layer(TraceLayer::new_for_http()); .layer(TraceLayer::new_for_http());

View file

@ -4,7 +4,7 @@
//! cargo run --example static_file_server //! cargo run --example static_file_server
//! ``` //! ```
use axum::{prelude::*, routing::nest, service::ServiceExt}; use axum::{prelude::*, routing::nest};
use http::StatusCode; use http::StatusCode;
use std::net::SocketAddr; use std::net::SocketAddr;
use tower_http::{services::ServeDir, trace::TraceLayer}; use tower_http::{services::ServeDir, trace::TraceLayer};
@ -19,12 +19,12 @@ async fn main() {
let app = nest( let app = nest(
"/static", "/static",
axum::service::get(ServeDir::new(".").handle_error(|error: std::io::Error| { axum::service::get(ServeDir::new(".")).handle_error(|error: std::io::Error| {
Ok::<_, std::convert::Infallible>(( Ok::<_, std::convert::Infallible>((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error), format!("Unhandled internal error: {}", error),
)) ))
})), }),
) )
.layer(TraceLayer::new_for_http()); .layer(TraceLayer::new_for_http());

View file

@ -17,7 +17,6 @@ use axum::{
extract::{Extension, Json, Path, Query}, extract::{Extension, Json, Path, Query},
prelude::*, prelude::*,
response::IntoResponse, response::IntoResponse,
service::ServiceExt,
}; };
use http::StatusCode; use http::StatusCode;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};

View file

@ -14,7 +14,6 @@ use axum::{
prelude::*, prelude::*,
response::IntoResponse, response::IntoResponse,
routing::nest, routing::nest,
service::ServiceExt,
}; };
use http::StatusCode; use http::StatusCode;
use std::net::SocketAddr; use std::net::SocketAddr;
@ -35,15 +34,14 @@ async fn main() {
let app = nest( let app = nest(
"/", "/",
axum::service::get( axum::service::get(
ServeDir::new("examples/websocket") ServeDir::new("examples/websocket").append_index_html_on_directories(true),
.append_index_html_on_directories(true) )
.handle_error(|error: std::io::Error| { .handle_error(|error: std::io::Error| {
Ok::<_, std::convert::Infallible>(( Ok::<_, std::convert::Infallible>((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error), format!("Unhandled internal error: {}", error),
)) ))
}), }),
),
) )
// routes are matched from bottom to top, so we have to put `nest` at the // routes are matched from bottom to top, so we have to put `nest` at the
// top since it matches all routes // top since it matches all routes

View file

@ -366,9 +366,9 @@ impl<S, T> Layered<S, T> {
/// This is used to convert errors to responses rather than simply /// This is used to convert errors to responses rather than simply
/// terminating the connection. /// terminating the connection.
/// ///
/// It works similarly to [`routing::Layered::handle_error`]. See that for more details. /// It works similarly to [`routing::RoutingDsl::handle_error`]. See that for more details.
/// ///
/// [`routing::Layered::handle_error`]: crate::routing::Layered::handle_error /// [`routing::RoutingDsl::handle_error`]: crate::routing::RoutingDsl::handle_error
pub fn handle_error<F, ReqBody, ResBody, Res, E>( pub fn handle_error<F, ReqBody, ResBody, Res, E>(
self, self,
f: F, f: F,

View file

@ -493,7 +493,7 @@
//! return `Result<T, E>` where `T` implements //! return `Result<T, E>` where `T` implements
//! [`IntoResponse`](response::IntoResponse). //! [`IntoResponse`](response::IntoResponse).
//! //!
//! See [`routing::Layered::handle_error`] for more details. //! See [`routing::RoutingDsl::handle_error`] for more details.
//! //!
//! ## Applying multiple middleware //! ## Applying multiple middleware
//! //!

View file

@ -6,7 +6,7 @@ use crate::{
buffer::MpscBuffer, buffer::MpscBuffer,
extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo}, extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo},
response::IntoResponse, response::IntoResponse,
service::HandleErrorFromRouter, service::{HandleError, HandleErrorFromRouter},
util::ByteStr, util::ByteStr,
}; };
use async_trait::async_trait; use async_trait::async_trait;
@ -348,6 +348,94 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized {
second: other, second: other,
} }
} }
/// Handle errors services in this router might produce, by mapping them to
/// responses.
///
/// Unhandled errors will close the connection without sending a response.
///
/// # Example
///
/// ```
/// use axum::{http::StatusCode, prelude::*};
/// use tower::{BoxError, timeout::TimeoutLayer};
/// use std::{time::Duration, convert::Infallible};
///
/// // This router can never fail, since handlers can never fail.
/// let app = route("/", get(|| async {}));
///
/// // Now the router can fail since the `tower::timeout::Timeout`
/// // middleware will return an error if the timeout elapses.
/// let app = app.layer(TimeoutLayer::new(Duration::from_secs(10)));
///
/// // With `handle_error` we can handle errors `Timeout` might produce.
/// // Our router now cannot fail, that is its error type is `Infallible`.
/// let app = app.handle_error(|error: BoxError| {
/// if error.is::<tower::timeout::error::Elapsed>() {
/// Ok::<_, Infallible>((
/// StatusCode::REQUEST_TIMEOUT,
/// "request took too long to handle".to_string(),
/// ))
/// } else {
/// Ok::<_, Infallible>((
/// StatusCode::INTERNAL_SERVER_ERROR,
/// format!("Unhandled error: {}", error),
/// ))
/// }
/// });
/// # async {
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// You can return `Err(_)` from the closure if you don't wish to handle
/// some errors:
///
/// ```
/// use axum::{http::StatusCode, prelude::*};
/// use tower::{BoxError, timeout::TimeoutLayer};
/// use std::time::Duration;
///
/// let app = route("/", get(|| async {}))
/// .layer(TimeoutLayer::new(Duration::from_secs(10)))
/// .handle_error(|error: BoxError| {
/// if error.is::<tower::timeout::error::Elapsed>() {
/// Ok((
/// StatusCode::REQUEST_TIMEOUT,
/// "request took too long to handle".to_string(),
/// ))
/// } else {
/// // return the error as is
/// Err(error)
/// }
/// });
/// # async {
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
fn handle_error<ReqBody, ResBody, F, Res, E>(
self,
f: F,
) -> HandleError<Self, F, ReqBody, HandleErrorFromRouter>
where
Self: Service<Request<ReqBody>, Response = Response<ResBody>>,
F: FnOnce(Self::Error) -> Result<Res, E>,
Res: IntoResponse,
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
ResBody::Error: Into<BoxError> + Send + Sync + 'static,
{
HandleError::new(self, f)
}
/// Check that your service cannot fail.
///
/// That is, its error type is [`Infallible`].
fn check_infallible<ReqBody>(self) -> Self
where
Self: Service<Request<ReqBody>, Error = Infallible>,
{
self
}
} }
impl<S, F> RoutingDsl for Route<S, F> {} impl<S, F> RoutingDsl for Route<S, F> {}
@ -646,97 +734,6 @@ impl<S> RoutingDsl for Layered<S> {}
impl<S> crate::sealed::Sealed for Layered<S> {} impl<S> crate::sealed::Sealed for Layered<S> {}
impl<S> Layered<S> {
/// Create a new [`Layered`] service where errors will be handled using the
/// given closure.
///
/// This is used to convert errors to responses rather than simply
/// terminating the connection.
///
/// That can be done using `handle_error` like so:
///
/// ```rust
/// use axum::prelude::*;
/// use http::StatusCode;
/// use tower::{BoxError, timeout::TimeoutLayer};
/// use std::{convert::Infallible, time::Duration};
///
/// async fn handler() { /* ... */ }
///
/// // `Timeout` will fail with `BoxError` if the timeout elapses...
/// let layered_app = route("/", get(handler))
/// .layer(TimeoutLayer::new(Duration::from_secs(30)));
///
/// // ...so we should handle that error
/// let with_errors_handled = layered_app.handle_error(|error: BoxError| {
/// if error.is::<tower::timeout::error::Elapsed>() {
/// Ok::<_, Infallible>((
/// StatusCode::REQUEST_TIMEOUT,
/// "request took too long".to_string(),
/// ))
/// } else {
/// Ok::<_, Infallible>((
/// StatusCode::INTERNAL_SERVER_ERROR,
/// format!("Unhandled internal error: {}", error),
/// ))
/// }
/// });
/// # async {
/// # axum::Server::bind(&"".parse().unwrap())
/// # .serve(with_errors_handled.into_make_service())
/// # .await
/// # .unwrap();
/// # };
/// ```
///
/// The closure must return `Result<T, E>` where `T` implements [`IntoResponse`].
///
/// You can also return `Err(_)` if you don't wish to handle the error:
///
/// ```rust
/// use axum::prelude::*;
/// use http::StatusCode;
/// use tower::{BoxError, timeout::TimeoutLayer};
/// use std::time::Duration;
///
/// async fn handler() { /* ... */ }
///
/// let layered_app = route("/", get(handler))
/// .layer(TimeoutLayer::new(Duration::from_secs(30)));
///
/// let with_errors_handled = layered_app.handle_error(|error: BoxError| {
/// if error.is::<tower::timeout::error::Elapsed>() {
/// Ok((
/// StatusCode::REQUEST_TIMEOUT,
/// "request took too long".to_string(),
/// ))
/// } else {
/// // keep the error as is
/// Err(error)
/// }
/// });
/// # async {
/// # axum::Server::bind(&"".parse().unwrap())
/// # .serve(with_errors_handled.into_make_service())
/// # .await
/// # .unwrap();
/// # };
/// ```
pub fn handle_error<F, ReqBody, ResBody, Res, E>(
self,
f: F,
) -> crate::service::HandleError<S, F, ReqBody, HandleErrorFromRouter>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
F: FnOnce(S::Error) -> Result<Res, E>,
Res: IntoResponse,
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
ResBody::Error: Into<BoxError> + Send + Sync + 'static,
{
crate::service::HandleError::new(self.inner, f)
}
}
impl<S, R> Service<R> for Layered<S> impl<S, R> Service<R> for Layered<S>
where where
S: Service<R>, S: Service<R>,
@ -809,7 +806,7 @@ where
/// ///
/// ``` /// ```
/// use axum::{ /// use axum::{
/// routing::nest, service::{get, ServiceExt}, prelude::*, /// routing::nest, service::get, prelude::*,
/// }; /// };
/// use tower_http::services::ServeDir; /// use tower_http::services::ServeDir;
/// ///

View file

@ -94,7 +94,6 @@ use crate::{
use bytes::Bytes; use bytes::Bytes;
use http::{Request, Response}; use http::{Request, Response};
use std::{ use std::{
convert::Infallible,
fmt, fmt,
marker::PhantomData, marker::PhantomData,
task::{Context, Poll}, task::{Context, Poll},
@ -428,6 +427,26 @@ impl<S, F> OnMethod<S, F> {
fallback: self, fallback: self,
} }
} }
/// Handle errors this service might produce, by mapping them to responses.
///
/// Unhandled errors will close the connection without sending a response.
///
/// Works similarly to [`RoutingDsl::handle_error`]. See that for more
/// details.
///
/// [`RoutingDsl::handle_error`]: crate::routing::RoutingDsl::handle_error
pub fn handle_error<ReqBody, H, Res, E>(
self,
f: H,
) -> HandleError<Self, H, ReqBody, HandleErrorFromService>
where
Self: Service<Request<ReqBody>, Response = Response<BoxBody>>,
H: FnOnce(<Self as Service<Request<ReqBody>>>::Error) -> Result<Res, E>,
Res: IntoResponse,
{
HandleError::new(self, f)
}
} }
// this is identical to `routing::OnMethod`'s implementation. Would be nice to find a way to clean // this is identical to `routing::OnMethod`'s implementation. Would be nice to find a way to clean
@ -462,7 +481,7 @@ where
/// ///
/// Created with /// Created with
/// [`handler::Layered::handle_error`](crate::handler::Layered::handle_error) or /// [`handler::Layered::handle_error`](crate::handler::Layered::handle_error) or
/// [`routing::Layered::handle_error`](crate::routing::Layered::handle_error). /// [`routing::RoutingDsl::handle_error`](crate::routing::RoutingDsl::handle_error).
/// See those methods for more details. /// See those methods for more details.
pub struct HandleError<S, F, B, T> { pub struct HandleError<S, F, B, T> {
inner: S, inner: S,
@ -542,75 +561,6 @@ where
} }
} }
/// Extension trait that adds additional methods to [`Service`].
pub trait ServiceExt<ReqBody, ResBody>:
Service<Request<ReqBody>, Response = Response<ResBody>>
{
/// Handle errors from a service.
///
/// `handle_error` takes a closure that will map errors from the service
/// into responses. The closure's return type must be `Result<T, E>` where
/// `T` implements [`IntoResponse`](crate::response::IntoResponse).
///
/// # Example
///
/// ```rust,no_run
/// use axum::{service::{self, ServiceExt}, prelude::*};
/// use http::{Response, StatusCode};
/// use tower::{service_fn, BoxError};
/// use std::convert::Infallible;
///
/// // A service that might fail with `std::io::Error`
/// let service = service_fn(|_: Request<Body>| async {
/// let res = Response::new(Body::empty());
/// Ok::<_, std::io::Error>(res)
/// });
///
/// let app = route(
/// "/",
/// service.handle_error(|error: std::io::Error| {
/// Ok::<_, Infallible>((
/// StatusCode::INTERNAL_SERVER_ERROR,
/// error.to_string(),
/// ))
/// }),
/// );
/// #
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// It works similarly to [`routing::Layered::handle_error`]. See that for more details.
///
/// [`routing::Layered::handle_error`]: crate::routing::Layered::handle_error
fn handle_error<F, Res, E>(self, f: F) -> HandleError<Self, F, ReqBody, HandleErrorFromService>
where
Self: Sized,
F: FnOnce(Self::Error) -> Result<Res, E>,
Res: IntoResponse,
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
ResBody::Error: Into<BoxError> + Send + Sync + 'static,
{
HandleError::new(self, f)
}
/// Check that your service cannot fail.
///
/// That is, its error type is [`Infallible`].
fn check_infallible(self) -> Self
where
Self: Service<Request<ReqBody>, Response = Response<ResBody>, Error = Infallible> + Sized,
{
self
}
}
impl<S, ReqBody, ResBody> ServiceExt<ReqBody, ResBody> for S where
S: Service<Request<ReqBody>, Response = Response<ResBody>>
{
}
/// A [`Service`] that boxes response bodies. /// A [`Service`] that boxes response bodies.
pub struct BoxResponseBody<S, B> { pub struct BoxResponseBody<S, B> {
inner: S, inner: S,

View file

@ -16,9 +16,9 @@ use std::{
task::{Context, Poll}, task::{Context, Poll},
time::Duration, time::Duration,
}; };
use tower::{make::Shared, service_fn, BoxError, Service, ServiceBuilder}; use tower::{make::Shared, service_fn, BoxError, Service};
use tower_http::{compression::CompressionLayer, trace::TraceLayer};
mod handle_error;
mod nest; mod nest;
mod or; mod or;
@ -326,53 +326,6 @@ async fn boxing() {
assert_eq!(res.text().await.unwrap(), "hi from POST"); assert_eq!(res.text().await.unwrap(), "hi from POST");
} }
#[tokio::test]
async fn service_handlers() {
use crate::service::ServiceExt as _;
use tower_http::services::ServeFile;
let app = route(
"/echo",
service::post(
service_fn(|req: Request<Body>| async move {
Ok::<_, BoxError>(Response::new(req.into_body()))
})
.handle_error(|_error: BoxError| Ok(StatusCode::INTERNAL_SERVER_ERROR)),
),
)
.route(
"/static/Cargo.toml",
service::on(
MethodFilter::GET,
ServeFile::new("Cargo.toml").handle_error(|error: std::io::Error| {
Ok::<_, Infallible>((StatusCode::INTERNAL_SERVER_ERROR, error.to_string()))
}),
),
);
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.post(format!("http://{}/echo", addr))
.body("foobar")
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "foobar");
let res = client
.get(format!("http://{}/static/Cargo.toml", addr))
.body("foobar")
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert!(res.text().await.unwrap().contains("edition ="));
}
#[tokio::test] #[tokio::test]
async fn routing_between_services() { async fn routing_between_services() {
use std::convert::Infallible; use std::convert::Infallible;
@ -466,55 +419,6 @@ async fn middleware_on_single_route() {
assert_eq!(body, "Hello, World!"); assert_eq!(body, "Hello, World!");
} }
#[tokio::test]
async fn handling_errors_from_layered_single_routes() {
async fn handle(_req: Request<Body>) -> &'static str {
tokio::time::sleep(Duration::from_secs(10)).await;
""
}
let app = route(
"/",
get(handle
.layer(
ServiceBuilder::new()
.timeout(Duration::from_millis(100))
.layer(TraceLayer::new_for_http())
.into_inner(),
)
.handle_error(|_error: BoxError| {
Ok::<_, Infallible>(StatusCode::INTERNAL_SERVER_ERROR)
})),
);
let addr = run_in_background(app).await;
let res = reqwest::get(format!("http://{}", addr)).await.unwrap();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
#[tokio::test]
async fn layer_on_whole_router() {
async fn handle(_req: Request<Body>) -> &'static str {
tokio::time::sleep(Duration::from_secs(10)).await;
""
}
let app = route("/", get(handle))
.layer(
ServiceBuilder::new()
.layer(CompressionLayer::new())
.timeout(Duration::from_millis(100))
.into_inner(),
)
.handle_error(|_err: BoxError| Ok::<_, Infallible>(StatusCode::INTERNAL_SERVER_ERROR));
let addr = run_in_background(app).await;
let res = reqwest::get(format!("http://{}", addr)).await.unwrap();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
#[tokio::test] #[tokio::test]
#[cfg(feature = "header")] #[cfg(feature = "header")]
async fn typed_header() { async fn typed_header() {

203
src/tests/handle_error.rs Normal file
View file

@ -0,0 +1,203 @@
use super::*;
use futures_util::future::{pending, ready};
use tower::{timeout::TimeoutLayer, MakeService};
async fn unit() {}
async fn forever() {
pending().await
}
fn timeout() -> TimeoutLayer {
TimeoutLayer::new(Duration::from_millis(10))
}
#[derive(Clone)]
struct Svc;
impl<R> Service<R> for Svc {
type Response = Response<Body>;
type Error = hyper::Error;
type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: R) -> Self::Future {
ready(Ok(Response::new(Body::empty())))
}
}
fn check_make_svc<M, R, T, E>(_make_svc: M)
where
M: MakeService<(), R, Response = T, Error = E>,
{
}
fn handle_error<E>(_: E) -> Result<StatusCode, Infallible> {
Ok(StatusCode::INTERNAL_SERVER_ERROR)
}
#[tokio::test]
async fn handler() {
let app = route(
"/",
get(forever
.layer(timeout())
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT))),
);
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
}
#[tokio::test]
async fn handler_multiple_methods_first() {
let app = route(
"/",
get(forever
.layer(timeout())
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT)))
.post(unit),
);
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
}
#[tokio::test]
async fn handler_multiple_methods_middle() {
let app = route(
"/",
delete(unit)
.get(
forever
.layer(timeout())
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT)),
)
.post(unit),
);
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
}
#[tokio::test]
async fn handler_multiple_methods_last() {
let app = route(
"/",
delete(unit).get(
forever
.layer(timeout())
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT)),
),
);
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
}
#[test]
fn service_propagates_errors() {
let app = route::<_, Body>("/echo", service::post(Svc));
check_make_svc::<_, _, _, hyper::Error>(app.into_make_service());
}
#[test]
fn service_nested_propagates_errors() {
let app = route::<_, Body>("/echo", nest("/foo", service::post(Svc)));
check_make_svc::<_, _, _, hyper::Error>(app.into_make_service());
}
#[test]
fn service_handle_on_method() {
let app = route::<_, Body>(
"/echo",
service::get(Svc).handle_error(handle_error::<hyper::Error>),
);
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
}
#[test]
fn service_handle_on_method_multiple() {
let app = route::<_, Body>(
"/echo",
service::get(Svc)
.post(Svc)
.handle_error(handle_error::<hyper::Error>),
);
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
}
#[test]
fn service_handle_on_router() {
let app =
route::<_, Body>("/echo", service::get(Svc)).handle_error(handle_error::<hyper::Error>);
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
}
#[test]
fn service_handle_on_router_still_impls_routing_dsl() {
let app = route::<_, Body>("/echo", service::get(Svc))
.handle_error(handle_error::<hyper::Error>)
.route("/", get(unit));
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
}
#[test]
fn layered() {
let app = route::<_, Body>("/echo", get(unit))
.layer(timeout())
.handle_error(handle_error::<BoxError>);
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
}
#[tokio::test] // async because of `.boxed()`
async fn layered_boxed() {
let app = route::<_, Body>("/echo", get(unit))
.layer(timeout())
.boxed()
.handle_error(handle_error::<BoxError>);
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
}