mirror of
https://github.com/tokio-rs/axum.git
synced 2024-12-28 07:20:12 +01:00
Move methods from ServiceExt
to RoutingDsl
(#160)
Previously, on `main`, this wouldn't compile: ```rust let app = route("/", get(handler)) .layer( ServiceBuilder::new() .timeout(Duration::from_secs(10)) .into_inner(), ) .handle_error(...) .route(...); // <-- doesn't work ``` That is because `handle_error` would be `axum::service::ServiceExt::handle_error` which returns `HandleError<_, _, _, HandleErrorFromService>` which does _not_ implement `RoutingDsl`. So you couldn't call `route`. This was caused by https://github.com/tokio-rs/axum/pull/120. Basically `handle_error` when called on a `RoutingDsl`, the resulting service should also implement `RoutingDsl`, but if called on another random service it should _not_ implement `RoutingDsl`. I don't think thats possible by having `handle_error` on `ServiceExt` which is implemented for any service, since all axum routers are also services by design. This resolves the issue by removing `ServiceExt` and moving its methods to `RoutingDsl`. Then we have more tight control over what has a `handle_error` method. `service::OnMethod` now also has a `handle_error` so you can still handle errors from random services, by doing `service::any(svc).handle_error(...)`.
This commit is contained in:
parent
9b3f3c9bdf
commit
8013165908
12 changed files with 346 additions and 297 deletions
|
@ -9,8 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
- Make `FromRequest` default to being generic over `body::Body` ([#146](https://github.com/tokio-rs/axum/pull/146))
|
||||
- Implement `std::error::Error` for all rejections ([#153](https://github.com/tokio-rs/axum/pull/153))
|
||||
- Fix `Uri` extractor not being the full URI if using `nest` ([#156](https://github.com/tokio-rs/axum/pull/156))
|
||||
- Add `RoutingDsl::or` for combining routes ([#108](https://github.com/tokio-rs/axum/pull/108))
|
||||
- Add `handle_error` to `service::OnMethod` ([#160](https://github.com/tokio-rs/axum/pull/160))
|
||||
|
||||
## Breaking changes
|
||||
|
||||
|
@ -24,9 +24,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- Ensure a `HandleError` service created from `ServiceExt::handle_error`
|
||||
_does not_ implement `RoutingDsl` as that could lead to confusing routing
|
||||
behavior ([#120](https://github.com/tokio-rs/axum/pull/120))
|
||||
- Fix `Uri` extractor not being the full URI if using `nest` ([#156](https://github.com/tokio-rs/axum/pull/156))
|
||||
- Implement `routing::MethodFilter` via [`bitflags`](https://crates.io/crates/bitflags)
|
||||
- Removed `extract::UrlParams` and `extract::UrlParamsMap`. Use `extract::Path` instead
|
||||
- `EmptyRouter` now requires the response body to implement `Send + Sync + 'static'` ([#108](https://github.com/tokio-rs/axum/pull/108))
|
||||
- `ServiceExt` has been removed and its methods have been moved to `RoutingDsl` ([#160](https://github.com/tokio-rs/axum/pull/160))
|
||||
- These future types have been moved
|
||||
- `extract::extractor_middleware::ExtractorMiddlewareResponseFuture` moved
|
||||
to `extract::extractor_middleware::future::ResponseFuture` ([#133](https://github.com/tokio-rs/axum/pull/133))
|
||||
|
|
|
@ -12,7 +12,6 @@ use axum::{
|
|||
prelude::*,
|
||||
response::IntoResponse,
|
||||
routing::BoxRoute,
|
||||
service::ServiceExt,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use http::StatusCode;
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
//! cargo run --example sse --features=headers
|
||||
//! ```
|
||||
|
||||
use axum::{extract::TypedHeader, prelude::*, routing::nest, service::ServiceExt, sse::Event};
|
||||
use axum::{extract::TypedHeader, prelude::*, routing::nest, sse::Event};
|
||||
use futures::stream::{self, Stream};
|
||||
use http::StatusCode;
|
||||
use std::{convert::Infallible, net::SocketAddr, time::Duration};
|
||||
|
@ -19,22 +19,19 @@ async fn main() {
|
|||
}
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let static_files_service =
|
||||
axum::service::get(ServeDir::new("examples/sse").append_index_html_on_directories(true))
|
||||
.handle_error(|error: std::io::Error| {
|
||||
Ok::<_, std::convert::Infallible>((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
))
|
||||
});
|
||||
|
||||
// build our application with a route
|
||||
let app = nest(
|
||||
"/",
|
||||
axum::service::get(
|
||||
ServeDir::new("examples/sse")
|
||||
.append_index_html_on_directories(true)
|
||||
.handle_error(|error: std::io::Error| {
|
||||
Ok::<_, std::convert::Infallible>((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
))
|
||||
}),
|
||||
),
|
||||
)
|
||||
.route("/sse", axum::sse::sse(make_stream))
|
||||
.layer(TraceLayer::new_for_http());
|
||||
let app = nest("/", static_files_service)
|
||||
.route("/sse", axum::sse::sse(make_stream))
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
// run it
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
//! cargo run --example static_file_server
|
||||
//! ```
|
||||
|
||||
use axum::{prelude::*, routing::nest, service::ServiceExt};
|
||||
use axum::{prelude::*, routing::nest};
|
||||
use http::StatusCode;
|
||||
use std::net::SocketAddr;
|
||||
use tower_http::{services::ServeDir, trace::TraceLayer};
|
||||
|
@ -19,12 +19,12 @@ async fn main() {
|
|||
|
||||
let app = nest(
|
||||
"/static",
|
||||
axum::service::get(ServeDir::new(".").handle_error(|error: std::io::Error| {
|
||||
axum::service::get(ServeDir::new(".")).handle_error(|error: std::io::Error| {
|
||||
Ok::<_, std::convert::Infallible>((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
))
|
||||
})),
|
||||
}),
|
||||
)
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@ use axum::{
|
|||
extract::{Extension, Json, Path, Query},
|
||||
prelude::*,
|
||||
response::IntoResponse,
|
||||
service::ServiceExt,
|
||||
};
|
||||
use http::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
|
|
@ -14,7 +14,6 @@ use axum::{
|
|||
prelude::*,
|
||||
response::IntoResponse,
|
||||
routing::nest,
|
||||
service::ServiceExt,
|
||||
};
|
||||
use http::StatusCode;
|
||||
use std::net::SocketAddr;
|
||||
|
@ -35,15 +34,14 @@ async fn main() {
|
|||
let app = nest(
|
||||
"/",
|
||||
axum::service::get(
|
||||
ServeDir::new("examples/websocket")
|
||||
.append_index_html_on_directories(true)
|
||||
.handle_error(|error: std::io::Error| {
|
||||
Ok::<_, std::convert::Infallible>((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
))
|
||||
}),
|
||||
),
|
||||
ServeDir::new("examples/websocket").append_index_html_on_directories(true),
|
||||
)
|
||||
.handle_error(|error: std::io::Error| {
|
||||
Ok::<_, std::convert::Infallible>((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
))
|
||||
}),
|
||||
)
|
||||
// routes are matched from bottom to top, so we have to put `nest` at the
|
||||
// top since it matches all routes
|
||||
|
|
|
@ -366,9 +366,9 @@ impl<S, T> Layered<S, T> {
|
|||
/// This is used to convert errors to responses rather than simply
|
||||
/// terminating the connection.
|
||||
///
|
||||
/// It works similarly to [`routing::Layered::handle_error`]. See that for more details.
|
||||
/// It works similarly to [`routing::RoutingDsl::handle_error`]. See that for more details.
|
||||
///
|
||||
/// [`routing::Layered::handle_error`]: crate::routing::Layered::handle_error
|
||||
/// [`routing::RoutingDsl::handle_error`]: crate::routing::RoutingDsl::handle_error
|
||||
pub fn handle_error<F, ReqBody, ResBody, Res, E>(
|
||||
self,
|
||||
f: F,
|
||||
|
|
|
@ -493,7 +493,7 @@
|
|||
//! return `Result<T, E>` where `T` implements
|
||||
//! [`IntoResponse`](response::IntoResponse).
|
||||
//!
|
||||
//! See [`routing::Layered::handle_error`] for more details.
|
||||
//! See [`routing::RoutingDsl::handle_error`] for more details.
|
||||
//!
|
||||
//! ## Applying multiple middleware
|
||||
//!
|
||||
|
|
183
src/routing.rs
183
src/routing.rs
|
@ -6,7 +6,7 @@ use crate::{
|
|||
buffer::MpscBuffer,
|
||||
extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo},
|
||||
response::IntoResponse,
|
||||
service::HandleErrorFromRouter,
|
||||
service::{HandleError, HandleErrorFromRouter},
|
||||
util::ByteStr,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
|
@ -348,6 +348,94 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized {
|
|||
second: other,
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle errors services in this router might produce, by mapping them to
|
||||
/// responses.
|
||||
///
|
||||
/// Unhandled errors will close the connection without sending a response.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use axum::{http::StatusCode, prelude::*};
|
||||
/// use tower::{BoxError, timeout::TimeoutLayer};
|
||||
/// use std::{time::Duration, convert::Infallible};
|
||||
///
|
||||
/// // This router can never fail, since handlers can never fail.
|
||||
/// let app = route("/", get(|| async {}));
|
||||
///
|
||||
/// // Now the router can fail since the `tower::timeout::Timeout`
|
||||
/// // middleware will return an error if the timeout elapses.
|
||||
/// let app = app.layer(TimeoutLayer::new(Duration::from_secs(10)));
|
||||
///
|
||||
/// // With `handle_error` we can handle errors `Timeout` might produce.
|
||||
/// // Our router now cannot fail, that is its error type is `Infallible`.
|
||||
/// let app = app.handle_error(|error: BoxError| {
|
||||
/// if error.is::<tower::timeout::error::Elapsed>() {
|
||||
/// Ok::<_, Infallible>((
|
||||
/// StatusCode::REQUEST_TIMEOUT,
|
||||
/// "request took too long to handle".to_string(),
|
||||
/// ))
|
||||
/// } else {
|
||||
/// Ok::<_, Infallible>((
|
||||
/// StatusCode::INTERNAL_SERVER_ERROR,
|
||||
/// format!("Unhandled error: {}", error),
|
||||
/// ))
|
||||
/// }
|
||||
/// });
|
||||
/// # async {
|
||||
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
///
|
||||
/// You can return `Err(_)` from the closure if you don't wish to handle
|
||||
/// some errors:
|
||||
///
|
||||
/// ```
|
||||
/// use axum::{http::StatusCode, prelude::*};
|
||||
/// use tower::{BoxError, timeout::TimeoutLayer};
|
||||
/// use std::time::Duration;
|
||||
///
|
||||
/// let app = route("/", get(|| async {}))
|
||||
/// .layer(TimeoutLayer::new(Duration::from_secs(10)))
|
||||
/// .handle_error(|error: BoxError| {
|
||||
/// if error.is::<tower::timeout::error::Elapsed>() {
|
||||
/// Ok((
|
||||
/// StatusCode::REQUEST_TIMEOUT,
|
||||
/// "request took too long to handle".to_string(),
|
||||
/// ))
|
||||
/// } else {
|
||||
/// // return the error as is
|
||||
/// Err(error)
|
||||
/// }
|
||||
/// });
|
||||
/// # async {
|
||||
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
fn handle_error<ReqBody, ResBody, F, Res, E>(
|
||||
self,
|
||||
f: F,
|
||||
) -> HandleError<Self, F, ReqBody, HandleErrorFromRouter>
|
||||
where
|
||||
Self: Service<Request<ReqBody>, Response = Response<ResBody>>,
|
||||
F: FnOnce(Self::Error) -> Result<Res, E>,
|
||||
Res: IntoResponse,
|
||||
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
|
||||
ResBody::Error: Into<BoxError> + Send + Sync + 'static,
|
||||
{
|
||||
HandleError::new(self, f)
|
||||
}
|
||||
|
||||
/// Check that your service cannot fail.
|
||||
///
|
||||
/// That is, its error type is [`Infallible`].
|
||||
fn check_infallible<ReqBody>(self) -> Self
|
||||
where
|
||||
Self: Service<Request<ReqBody>, Error = Infallible>,
|
||||
{
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, F> RoutingDsl for Route<S, F> {}
|
||||
|
@ -646,97 +734,6 @@ impl<S> RoutingDsl for Layered<S> {}
|
|||
|
||||
impl<S> crate::sealed::Sealed for Layered<S> {}
|
||||
|
||||
impl<S> Layered<S> {
|
||||
/// Create a new [`Layered`] service where errors will be handled using the
|
||||
/// given closure.
|
||||
///
|
||||
/// This is used to convert errors to responses rather than simply
|
||||
/// terminating the connection.
|
||||
///
|
||||
/// That can be done using `handle_error` like so:
|
||||
///
|
||||
/// ```rust
|
||||
/// use axum::prelude::*;
|
||||
/// use http::StatusCode;
|
||||
/// use tower::{BoxError, timeout::TimeoutLayer};
|
||||
/// use std::{convert::Infallible, time::Duration};
|
||||
///
|
||||
/// async fn handler() { /* ... */ }
|
||||
///
|
||||
/// // `Timeout` will fail with `BoxError` if the timeout elapses...
|
||||
/// let layered_app = route("/", get(handler))
|
||||
/// .layer(TimeoutLayer::new(Duration::from_secs(30)));
|
||||
///
|
||||
/// // ...so we should handle that error
|
||||
/// let with_errors_handled = layered_app.handle_error(|error: BoxError| {
|
||||
/// if error.is::<tower::timeout::error::Elapsed>() {
|
||||
/// Ok::<_, Infallible>((
|
||||
/// StatusCode::REQUEST_TIMEOUT,
|
||||
/// "request took too long".to_string(),
|
||||
/// ))
|
||||
/// } else {
|
||||
/// Ok::<_, Infallible>((
|
||||
/// StatusCode::INTERNAL_SERVER_ERROR,
|
||||
/// format!("Unhandled internal error: {}", error),
|
||||
/// ))
|
||||
/// }
|
||||
/// });
|
||||
/// # async {
|
||||
/// # axum::Server::bind(&"".parse().unwrap())
|
||||
/// # .serve(with_errors_handled.into_make_service())
|
||||
/// # .await
|
||||
/// # .unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
///
|
||||
/// The closure must return `Result<T, E>` where `T` implements [`IntoResponse`].
|
||||
///
|
||||
/// You can also return `Err(_)` if you don't wish to handle the error:
|
||||
///
|
||||
/// ```rust
|
||||
/// use axum::prelude::*;
|
||||
/// use http::StatusCode;
|
||||
/// use tower::{BoxError, timeout::TimeoutLayer};
|
||||
/// use std::time::Duration;
|
||||
///
|
||||
/// async fn handler() { /* ... */ }
|
||||
///
|
||||
/// let layered_app = route("/", get(handler))
|
||||
/// .layer(TimeoutLayer::new(Duration::from_secs(30)));
|
||||
///
|
||||
/// let with_errors_handled = layered_app.handle_error(|error: BoxError| {
|
||||
/// if error.is::<tower::timeout::error::Elapsed>() {
|
||||
/// Ok((
|
||||
/// StatusCode::REQUEST_TIMEOUT,
|
||||
/// "request took too long".to_string(),
|
||||
/// ))
|
||||
/// } else {
|
||||
/// // keep the error as is
|
||||
/// Err(error)
|
||||
/// }
|
||||
/// });
|
||||
/// # async {
|
||||
/// # axum::Server::bind(&"".parse().unwrap())
|
||||
/// # .serve(with_errors_handled.into_make_service())
|
||||
/// # .await
|
||||
/// # .unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
pub fn handle_error<F, ReqBody, ResBody, Res, E>(
|
||||
self,
|
||||
f: F,
|
||||
) -> crate::service::HandleError<S, F, ReqBody, HandleErrorFromRouter>
|
||||
where
|
||||
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
|
||||
F: FnOnce(S::Error) -> Result<Res, E>,
|
||||
Res: IntoResponse,
|
||||
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
|
||||
ResBody::Error: Into<BoxError> + Send + Sync + 'static,
|
||||
{
|
||||
crate::service::HandleError::new(self.inner, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, R> Service<R> for Layered<S>
|
||||
where
|
||||
S: Service<R>,
|
||||
|
@ -809,7 +806,7 @@ where
|
|||
///
|
||||
/// ```
|
||||
/// use axum::{
|
||||
/// routing::nest, service::{get, ServiceExt}, prelude::*,
|
||||
/// routing::nest, service::get, prelude::*,
|
||||
/// };
|
||||
/// use tower_http::services::ServeDir;
|
||||
///
|
||||
|
|
|
@ -94,7 +94,6 @@ use crate::{
|
|||
use bytes::Bytes;
|
||||
use http::{Request, Response};
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
fmt,
|
||||
marker::PhantomData,
|
||||
task::{Context, Poll},
|
||||
|
@ -428,6 +427,26 @@ impl<S, F> OnMethod<S, F> {
|
|||
fallback: self,
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle errors this service might produce, by mapping them to responses.
|
||||
///
|
||||
/// Unhandled errors will close the connection without sending a response.
|
||||
///
|
||||
/// Works similarly to [`RoutingDsl::handle_error`]. See that for more
|
||||
/// details.
|
||||
///
|
||||
/// [`RoutingDsl::handle_error`]: crate::routing::RoutingDsl::handle_error
|
||||
pub fn handle_error<ReqBody, H, Res, E>(
|
||||
self,
|
||||
f: H,
|
||||
) -> HandleError<Self, H, ReqBody, HandleErrorFromService>
|
||||
where
|
||||
Self: Service<Request<ReqBody>, Response = Response<BoxBody>>,
|
||||
H: FnOnce(<Self as Service<Request<ReqBody>>>::Error) -> Result<Res, E>,
|
||||
Res: IntoResponse,
|
||||
{
|
||||
HandleError::new(self, f)
|
||||
}
|
||||
}
|
||||
|
||||
// this is identical to `routing::OnMethod`'s implementation. Would be nice to find a way to clean
|
||||
|
@ -462,7 +481,7 @@ where
|
|||
///
|
||||
/// Created with
|
||||
/// [`handler::Layered::handle_error`](crate::handler::Layered::handle_error) or
|
||||
/// [`routing::Layered::handle_error`](crate::routing::Layered::handle_error).
|
||||
/// [`routing::RoutingDsl::handle_error`](crate::routing::RoutingDsl::handle_error).
|
||||
/// See those methods for more details.
|
||||
pub struct HandleError<S, F, B, T> {
|
||||
inner: S,
|
||||
|
@ -542,75 +561,6 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Extension trait that adds additional methods to [`Service`].
|
||||
pub trait ServiceExt<ReqBody, ResBody>:
|
||||
Service<Request<ReqBody>, Response = Response<ResBody>>
|
||||
{
|
||||
/// Handle errors from a service.
|
||||
///
|
||||
/// `handle_error` takes a closure that will map errors from the service
|
||||
/// into responses. The closure's return type must be `Result<T, E>` where
|
||||
/// `T` implements [`IntoResponse`](crate::response::IntoResponse).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use axum::{service::{self, ServiceExt}, prelude::*};
|
||||
/// use http::{Response, StatusCode};
|
||||
/// use tower::{service_fn, BoxError};
|
||||
/// use std::convert::Infallible;
|
||||
///
|
||||
/// // A service that might fail with `std::io::Error`
|
||||
/// let service = service_fn(|_: Request<Body>| async {
|
||||
/// let res = Response::new(Body::empty());
|
||||
/// Ok::<_, std::io::Error>(res)
|
||||
/// });
|
||||
///
|
||||
/// let app = route(
|
||||
/// "/",
|
||||
/// service.handle_error(|error: std::io::Error| {
|
||||
/// Ok::<_, Infallible>((
|
||||
/// StatusCode::INTERNAL_SERVER_ERROR,
|
||||
/// error.to_string(),
|
||||
/// ))
|
||||
/// }),
|
||||
/// );
|
||||
/// #
|
||||
/// # async {
|
||||
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
///
|
||||
/// It works similarly to [`routing::Layered::handle_error`]. See that for more details.
|
||||
///
|
||||
/// [`routing::Layered::handle_error`]: crate::routing::Layered::handle_error
|
||||
fn handle_error<F, Res, E>(self, f: F) -> HandleError<Self, F, ReqBody, HandleErrorFromService>
|
||||
where
|
||||
Self: Sized,
|
||||
F: FnOnce(Self::Error) -> Result<Res, E>,
|
||||
Res: IntoResponse,
|
||||
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
|
||||
ResBody::Error: Into<BoxError> + Send + Sync + 'static,
|
||||
{
|
||||
HandleError::new(self, f)
|
||||
}
|
||||
|
||||
/// Check that your service cannot fail.
|
||||
///
|
||||
/// That is, its error type is [`Infallible`].
|
||||
fn check_infallible(self) -> Self
|
||||
where
|
||||
Self: Service<Request<ReqBody>, Response = Response<ResBody>, Error = Infallible> + Sized,
|
||||
{
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, ReqBody, ResBody> ServiceExt<ReqBody, ResBody> for S where
|
||||
S: Service<Request<ReqBody>, Response = Response<ResBody>>
|
||||
{
|
||||
}
|
||||
|
||||
/// A [`Service`] that boxes response bodies.
|
||||
pub struct BoxResponseBody<S, B> {
|
||||
inner: S,
|
||||
|
|
100
src/tests.rs
100
src/tests.rs
|
@ -16,9 +16,9 @@ use std::{
|
|||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
use tower::{make::Shared, service_fn, BoxError, Service, ServiceBuilder};
|
||||
use tower_http::{compression::CompressionLayer, trace::TraceLayer};
|
||||
use tower::{make::Shared, service_fn, BoxError, Service};
|
||||
|
||||
mod handle_error;
|
||||
mod nest;
|
||||
mod or;
|
||||
|
||||
|
@ -326,53 +326,6 @@ async fn boxing() {
|
|||
assert_eq!(res.text().await.unwrap(), "hi from POST");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn service_handlers() {
|
||||
use crate::service::ServiceExt as _;
|
||||
use tower_http::services::ServeFile;
|
||||
|
||||
let app = route(
|
||||
"/echo",
|
||||
service::post(
|
||||
service_fn(|req: Request<Body>| async move {
|
||||
Ok::<_, BoxError>(Response::new(req.into_body()))
|
||||
})
|
||||
.handle_error(|_error: BoxError| Ok(StatusCode::INTERNAL_SERVER_ERROR)),
|
||||
),
|
||||
)
|
||||
.route(
|
||||
"/static/Cargo.toml",
|
||||
service::on(
|
||||
MethodFilter::GET,
|
||||
ServeFile::new("Cargo.toml").handle_error(|error: std::io::Error| {
|
||||
Ok::<_, Infallible>((StatusCode::INTERNAL_SERVER_ERROR, error.to_string()))
|
||||
}),
|
||||
),
|
||||
);
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client
|
||||
.post(format!("http://{}/echo", addr))
|
||||
.body("foobar")
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await.unwrap(), "foobar");
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/static/Cargo.toml", addr))
|
||||
.body("foobar")
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert!(res.text().await.unwrap().contains("edition ="));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn routing_between_services() {
|
||||
use std::convert::Infallible;
|
||||
|
@ -466,55 +419,6 @@ async fn middleware_on_single_route() {
|
|||
assert_eq!(body, "Hello, World!");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handling_errors_from_layered_single_routes() {
|
||||
async fn handle(_req: Request<Body>) -> &'static str {
|
||||
tokio::time::sleep(Duration::from_secs(10)).await;
|
||||
""
|
||||
}
|
||||
|
||||
let app = route(
|
||||
"/",
|
||||
get(handle
|
||||
.layer(
|
||||
ServiceBuilder::new()
|
||||
.timeout(Duration::from_millis(100))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.into_inner(),
|
||||
)
|
||||
.handle_error(|_error: BoxError| {
|
||||
Ok::<_, Infallible>(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
})),
|
||||
);
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let res = reqwest::get(format!("http://{}", addr)).await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn layer_on_whole_router() {
|
||||
async fn handle(_req: Request<Body>) -> &'static str {
|
||||
tokio::time::sleep(Duration::from_secs(10)).await;
|
||||
""
|
||||
}
|
||||
|
||||
let app = route("/", get(handle))
|
||||
.layer(
|
||||
ServiceBuilder::new()
|
||||
.layer(CompressionLayer::new())
|
||||
.timeout(Duration::from_millis(100))
|
||||
.into_inner(),
|
||||
)
|
||||
.handle_error(|_err: BoxError| Ok::<_, Infallible>(StatusCode::INTERNAL_SERVER_ERROR));
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let res = reqwest::get(format!("http://{}", addr)).await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[cfg(feature = "header")]
|
||||
async fn typed_header() {
|
||||
|
|
203
src/tests/handle_error.rs
Normal file
203
src/tests/handle_error.rs
Normal file
|
@ -0,0 +1,203 @@
|
|||
use super::*;
|
||||
use futures_util::future::{pending, ready};
|
||||
use tower::{timeout::TimeoutLayer, MakeService};
|
||||
|
||||
async fn unit() {}
|
||||
|
||||
async fn forever() {
|
||||
pending().await
|
||||
}
|
||||
|
||||
fn timeout() -> TimeoutLayer {
|
||||
TimeoutLayer::new(Duration::from_millis(10))
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Svc;
|
||||
|
||||
impl<R> Service<R> for Svc {
|
||||
type Response = Response<Body>;
|
||||
type Error = hyper::Error;
|
||||
type Future = Ready<Result<Self::Response, Self::Error>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, _req: R) -> Self::Future {
|
||||
ready(Ok(Response::new(Body::empty())))
|
||||
}
|
||||
}
|
||||
|
||||
fn check_make_svc<M, R, T, E>(_make_svc: M)
|
||||
where
|
||||
M: MakeService<(), R, Response = T, Error = E>,
|
||||
{
|
||||
}
|
||||
|
||||
fn handle_error<E>(_: E) -> Result<StatusCode, Infallible> {
|
||||
Ok(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handler() {
|
||||
let app = route(
|
||||
"/",
|
||||
get(forever
|
||||
.layer(timeout())
|
||||
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT))),
|
||||
);
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handler_multiple_methods_first() {
|
||||
let app = route(
|
||||
"/",
|
||||
get(forever
|
||||
.layer(timeout())
|
||||
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT)))
|
||||
.post(unit),
|
||||
);
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handler_multiple_methods_middle() {
|
||||
let app = route(
|
||||
"/",
|
||||
delete(unit)
|
||||
.get(
|
||||
forever
|
||||
.layer(timeout())
|
||||
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT)),
|
||||
)
|
||||
.post(unit),
|
||||
);
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handler_multiple_methods_last() {
|
||||
let app = route(
|
||||
"/",
|
||||
delete(unit).get(
|
||||
forever
|
||||
.layer(timeout())
|
||||
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT)),
|
||||
),
|
||||
);
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn service_propagates_errors() {
|
||||
let app = route::<_, Body>("/echo", service::post(Svc));
|
||||
|
||||
check_make_svc::<_, _, _, hyper::Error>(app.into_make_service());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn service_nested_propagates_errors() {
|
||||
let app = route::<_, Body>("/echo", nest("/foo", service::post(Svc)));
|
||||
|
||||
check_make_svc::<_, _, _, hyper::Error>(app.into_make_service());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn service_handle_on_method() {
|
||||
let app = route::<_, Body>(
|
||||
"/echo",
|
||||
service::get(Svc).handle_error(handle_error::<hyper::Error>),
|
||||
);
|
||||
|
||||
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn service_handle_on_method_multiple() {
|
||||
let app = route::<_, Body>(
|
||||
"/echo",
|
||||
service::get(Svc)
|
||||
.post(Svc)
|
||||
.handle_error(handle_error::<hyper::Error>),
|
||||
);
|
||||
|
||||
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn service_handle_on_router() {
|
||||
let app =
|
||||
route::<_, Body>("/echo", service::get(Svc)).handle_error(handle_error::<hyper::Error>);
|
||||
|
||||
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn service_handle_on_router_still_impls_routing_dsl() {
|
||||
let app = route::<_, Body>("/echo", service::get(Svc))
|
||||
.handle_error(handle_error::<hyper::Error>)
|
||||
.route("/", get(unit));
|
||||
|
||||
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn layered() {
|
||||
let app = route::<_, Body>("/echo", get(unit))
|
||||
.layer(timeout())
|
||||
.handle_error(handle_error::<BoxError>);
|
||||
|
||||
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
|
||||
}
|
||||
|
||||
#[tokio::test] // async because of `.boxed()`
|
||||
async fn layered_boxed() {
|
||||
let app = route::<_, Body>("/echo", get(unit))
|
||||
.layer(timeout())
|
||||
.boxed()
|
||||
.handle_error(handle_error::<BoxError>);
|
||||
|
||||
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
|
||||
}
|
Loading…
Reference in a new issue