mirror of
https://github.com/tokio-rs/axum.git
synced 2024-12-28 15:30:16 +01:00
Move methods from ServiceExt
to RoutingDsl
(#160)
Previously, on `main`, this wouldn't compile: ```rust let app = route("/", get(handler)) .layer( ServiceBuilder::new() .timeout(Duration::from_secs(10)) .into_inner(), ) .handle_error(...) .route(...); // <-- doesn't work ``` That is because `handle_error` would be `axum::service::ServiceExt::handle_error` which returns `HandleError<_, _, _, HandleErrorFromService>` which does _not_ implement `RoutingDsl`. So you couldn't call `route`. This was caused by https://github.com/tokio-rs/axum/pull/120. Basically `handle_error` when called on a `RoutingDsl`, the resulting service should also implement `RoutingDsl`, but if called on another random service it should _not_ implement `RoutingDsl`. I don't think thats possible by having `handle_error` on `ServiceExt` which is implemented for any service, since all axum routers are also services by design. This resolves the issue by removing `ServiceExt` and moving its methods to `RoutingDsl`. Then we have more tight control over what has a `handle_error` method. `service::OnMethod` now also has a `handle_error` so you can still handle errors from random services, by doing `service::any(svc).handle_error(...)`.
This commit is contained in:
parent
9b3f3c9bdf
commit
8013165908
12 changed files with 346 additions and 297 deletions
|
@ -9,8 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
|
|
||||||
- Make `FromRequest` default to being generic over `body::Body` ([#146](https://github.com/tokio-rs/axum/pull/146))
|
- Make `FromRequest` default to being generic over `body::Body` ([#146](https://github.com/tokio-rs/axum/pull/146))
|
||||||
- Implement `std::error::Error` for all rejections ([#153](https://github.com/tokio-rs/axum/pull/153))
|
- Implement `std::error::Error` for all rejections ([#153](https://github.com/tokio-rs/axum/pull/153))
|
||||||
- Fix `Uri` extractor not being the full URI if using `nest` ([#156](https://github.com/tokio-rs/axum/pull/156))
|
|
||||||
- Add `RoutingDsl::or` for combining routes ([#108](https://github.com/tokio-rs/axum/pull/108))
|
- Add `RoutingDsl::or` for combining routes ([#108](https://github.com/tokio-rs/axum/pull/108))
|
||||||
|
- Add `handle_error` to `service::OnMethod` ([#160](https://github.com/tokio-rs/axum/pull/160))
|
||||||
|
|
||||||
## Breaking changes
|
## Breaking changes
|
||||||
|
|
||||||
|
@ -24,9 +24,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
- Ensure a `HandleError` service created from `ServiceExt::handle_error`
|
- Ensure a `HandleError` service created from `ServiceExt::handle_error`
|
||||||
_does not_ implement `RoutingDsl` as that could lead to confusing routing
|
_does not_ implement `RoutingDsl` as that could lead to confusing routing
|
||||||
behavior ([#120](https://github.com/tokio-rs/axum/pull/120))
|
behavior ([#120](https://github.com/tokio-rs/axum/pull/120))
|
||||||
|
- Fix `Uri` extractor not being the full URI if using `nest` ([#156](https://github.com/tokio-rs/axum/pull/156))
|
||||||
- Implement `routing::MethodFilter` via [`bitflags`](https://crates.io/crates/bitflags)
|
- Implement `routing::MethodFilter` via [`bitflags`](https://crates.io/crates/bitflags)
|
||||||
- Removed `extract::UrlParams` and `extract::UrlParamsMap`. Use `extract::Path` instead
|
- Removed `extract::UrlParams` and `extract::UrlParamsMap`. Use `extract::Path` instead
|
||||||
- `EmptyRouter` now requires the response body to implement `Send + Sync + 'static'` ([#108](https://github.com/tokio-rs/axum/pull/108))
|
- `EmptyRouter` now requires the response body to implement `Send + Sync + 'static'` ([#108](https://github.com/tokio-rs/axum/pull/108))
|
||||||
|
- `ServiceExt` has been removed and its methods have been moved to `RoutingDsl` ([#160](https://github.com/tokio-rs/axum/pull/160))
|
||||||
- These future types have been moved
|
- These future types have been moved
|
||||||
- `extract::extractor_middleware::ExtractorMiddlewareResponseFuture` moved
|
- `extract::extractor_middleware::ExtractorMiddlewareResponseFuture` moved
|
||||||
to `extract::extractor_middleware::future::ResponseFuture` ([#133](https://github.com/tokio-rs/axum/pull/133))
|
to `extract::extractor_middleware::future::ResponseFuture` ([#133](https://github.com/tokio-rs/axum/pull/133))
|
||||||
|
|
|
@ -12,7 +12,6 @@ use axum::{
|
||||||
prelude::*,
|
prelude::*,
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
routing::BoxRoute,
|
routing::BoxRoute,
|
||||||
service::ServiceExt,
|
|
||||||
};
|
};
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
//! cargo run --example sse --features=headers
|
//! cargo run --example sse --features=headers
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
use axum::{extract::TypedHeader, prelude::*, routing::nest, service::ServiceExt, sse::Event};
|
use axum::{extract::TypedHeader, prelude::*, routing::nest, sse::Event};
|
||||||
use futures::stream::{self, Stream};
|
use futures::stream::{self, Stream};
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use std::{convert::Infallible, net::SocketAddr, time::Duration};
|
use std::{convert::Infallible, net::SocketAddr, time::Duration};
|
||||||
|
@ -19,22 +19,19 @@ async fn main() {
|
||||||
}
|
}
|
||||||
tracing_subscriber::fmt::init();
|
tracing_subscriber::fmt::init();
|
||||||
|
|
||||||
|
let static_files_service =
|
||||||
|
axum::service::get(ServeDir::new("examples/sse").append_index_html_on_directories(true))
|
||||||
|
.handle_error(|error: std::io::Error| {
|
||||||
|
Ok::<_, std::convert::Infallible>((
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Unhandled internal error: {}", error),
|
||||||
|
))
|
||||||
|
});
|
||||||
|
|
||||||
// build our application with a route
|
// build our application with a route
|
||||||
let app = nest(
|
let app = nest("/", static_files_service)
|
||||||
"/",
|
.route("/sse", axum::sse::sse(make_stream))
|
||||||
axum::service::get(
|
.layer(TraceLayer::new_for_http());
|
||||||
ServeDir::new("examples/sse")
|
|
||||||
.append_index_html_on_directories(true)
|
|
||||||
.handle_error(|error: std::io::Error| {
|
|
||||||
Ok::<_, std::convert::Infallible>((
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
format!("Unhandled internal error: {}", error),
|
|
||||||
))
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.route("/sse", axum::sse::sse(make_stream))
|
|
||||||
.layer(TraceLayer::new_for_http());
|
|
||||||
|
|
||||||
// run it
|
// run it
|
||||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
//! cargo run --example static_file_server
|
//! cargo run --example static_file_server
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
use axum::{prelude::*, routing::nest, service::ServiceExt};
|
use axum::{prelude::*, routing::nest};
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use tower_http::{services::ServeDir, trace::TraceLayer};
|
use tower_http::{services::ServeDir, trace::TraceLayer};
|
||||||
|
@ -19,12 +19,12 @@ async fn main() {
|
||||||
|
|
||||||
let app = nest(
|
let app = nest(
|
||||||
"/static",
|
"/static",
|
||||||
axum::service::get(ServeDir::new(".").handle_error(|error: std::io::Error| {
|
axum::service::get(ServeDir::new(".")).handle_error(|error: std::io::Error| {
|
||||||
Ok::<_, std::convert::Infallible>((
|
Ok::<_, std::convert::Infallible>((
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
format!("Unhandled internal error: {}", error),
|
format!("Unhandled internal error: {}", error),
|
||||||
))
|
))
|
||||||
})),
|
}),
|
||||||
)
|
)
|
||||||
.layer(TraceLayer::new_for_http());
|
.layer(TraceLayer::new_for_http());
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,6 @@ use axum::{
|
||||||
extract::{Extension, Json, Path, Query},
|
extract::{Extension, Json, Path, Query},
|
||||||
prelude::*,
|
prelude::*,
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
service::ServiceExt,
|
|
||||||
};
|
};
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
|
@ -14,7 +14,6 @@ use axum::{
|
||||||
prelude::*,
|
prelude::*,
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
routing::nest,
|
routing::nest,
|
||||||
service::ServiceExt,
|
|
||||||
};
|
};
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
@ -35,15 +34,14 @@ async fn main() {
|
||||||
let app = nest(
|
let app = nest(
|
||||||
"/",
|
"/",
|
||||||
axum::service::get(
|
axum::service::get(
|
||||||
ServeDir::new("examples/websocket")
|
ServeDir::new("examples/websocket").append_index_html_on_directories(true),
|
||||||
.append_index_html_on_directories(true)
|
)
|
||||||
.handle_error(|error: std::io::Error| {
|
.handle_error(|error: std::io::Error| {
|
||||||
Ok::<_, std::convert::Infallible>((
|
Ok::<_, std::convert::Infallible>((
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
format!("Unhandled internal error: {}", error),
|
format!("Unhandled internal error: {}", error),
|
||||||
))
|
))
|
||||||
}),
|
}),
|
||||||
),
|
|
||||||
)
|
)
|
||||||
// routes are matched from bottom to top, so we have to put `nest` at the
|
// routes are matched from bottom to top, so we have to put `nest` at the
|
||||||
// top since it matches all routes
|
// top since it matches all routes
|
||||||
|
|
|
@ -366,9 +366,9 @@ impl<S, T> Layered<S, T> {
|
||||||
/// This is used to convert errors to responses rather than simply
|
/// This is used to convert errors to responses rather than simply
|
||||||
/// terminating the connection.
|
/// terminating the connection.
|
||||||
///
|
///
|
||||||
/// It works similarly to [`routing::Layered::handle_error`]. See that for more details.
|
/// It works similarly to [`routing::RoutingDsl::handle_error`]. See that for more details.
|
||||||
///
|
///
|
||||||
/// [`routing::Layered::handle_error`]: crate::routing::Layered::handle_error
|
/// [`routing::RoutingDsl::handle_error`]: crate::routing::RoutingDsl::handle_error
|
||||||
pub fn handle_error<F, ReqBody, ResBody, Res, E>(
|
pub fn handle_error<F, ReqBody, ResBody, Res, E>(
|
||||||
self,
|
self,
|
||||||
f: F,
|
f: F,
|
||||||
|
|
|
@ -493,7 +493,7 @@
|
||||||
//! return `Result<T, E>` where `T` implements
|
//! return `Result<T, E>` where `T` implements
|
||||||
//! [`IntoResponse`](response::IntoResponse).
|
//! [`IntoResponse`](response::IntoResponse).
|
||||||
//!
|
//!
|
||||||
//! See [`routing::Layered::handle_error`] for more details.
|
//! See [`routing::RoutingDsl::handle_error`] for more details.
|
||||||
//!
|
//!
|
||||||
//! ## Applying multiple middleware
|
//! ## Applying multiple middleware
|
||||||
//!
|
//!
|
||||||
|
|
183
src/routing.rs
183
src/routing.rs
|
@ -6,7 +6,7 @@ use crate::{
|
||||||
buffer::MpscBuffer,
|
buffer::MpscBuffer,
|
||||||
extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo},
|
extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo},
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
service::HandleErrorFromRouter,
|
service::{HandleError, HandleErrorFromRouter},
|
||||||
util::ByteStr,
|
util::ByteStr,
|
||||||
};
|
};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
@ -348,6 +348,94 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized {
|
||||||
second: other,
|
second: other,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Handle errors services in this router might produce, by mapping them to
|
||||||
|
/// responses.
|
||||||
|
///
|
||||||
|
/// Unhandled errors will close the connection without sending a response.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use axum::{http::StatusCode, prelude::*};
|
||||||
|
/// use tower::{BoxError, timeout::TimeoutLayer};
|
||||||
|
/// use std::{time::Duration, convert::Infallible};
|
||||||
|
///
|
||||||
|
/// // This router can never fail, since handlers can never fail.
|
||||||
|
/// let app = route("/", get(|| async {}));
|
||||||
|
///
|
||||||
|
/// // Now the router can fail since the `tower::timeout::Timeout`
|
||||||
|
/// // middleware will return an error if the timeout elapses.
|
||||||
|
/// let app = app.layer(TimeoutLayer::new(Duration::from_secs(10)));
|
||||||
|
///
|
||||||
|
/// // With `handle_error` we can handle errors `Timeout` might produce.
|
||||||
|
/// // Our router now cannot fail, that is its error type is `Infallible`.
|
||||||
|
/// let app = app.handle_error(|error: BoxError| {
|
||||||
|
/// if error.is::<tower::timeout::error::Elapsed>() {
|
||||||
|
/// Ok::<_, Infallible>((
|
||||||
|
/// StatusCode::REQUEST_TIMEOUT,
|
||||||
|
/// "request took too long to handle".to_string(),
|
||||||
|
/// ))
|
||||||
|
/// } else {
|
||||||
|
/// Ok::<_, Infallible>((
|
||||||
|
/// StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
/// format!("Unhandled error: {}", error),
|
||||||
|
/// ))
|
||||||
|
/// }
|
||||||
|
/// });
|
||||||
|
/// # async {
|
||||||
|
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||||
|
/// # };
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// You can return `Err(_)` from the closure if you don't wish to handle
|
||||||
|
/// some errors:
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use axum::{http::StatusCode, prelude::*};
|
||||||
|
/// use tower::{BoxError, timeout::TimeoutLayer};
|
||||||
|
/// use std::time::Duration;
|
||||||
|
///
|
||||||
|
/// let app = route("/", get(|| async {}))
|
||||||
|
/// .layer(TimeoutLayer::new(Duration::from_secs(10)))
|
||||||
|
/// .handle_error(|error: BoxError| {
|
||||||
|
/// if error.is::<tower::timeout::error::Elapsed>() {
|
||||||
|
/// Ok((
|
||||||
|
/// StatusCode::REQUEST_TIMEOUT,
|
||||||
|
/// "request took too long to handle".to_string(),
|
||||||
|
/// ))
|
||||||
|
/// } else {
|
||||||
|
/// // return the error as is
|
||||||
|
/// Err(error)
|
||||||
|
/// }
|
||||||
|
/// });
|
||||||
|
/// # async {
|
||||||
|
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||||
|
/// # };
|
||||||
|
/// ```
|
||||||
|
fn handle_error<ReqBody, ResBody, F, Res, E>(
|
||||||
|
self,
|
||||||
|
f: F,
|
||||||
|
) -> HandleError<Self, F, ReqBody, HandleErrorFromRouter>
|
||||||
|
where
|
||||||
|
Self: Service<Request<ReqBody>, Response = Response<ResBody>>,
|
||||||
|
F: FnOnce(Self::Error) -> Result<Res, E>,
|
||||||
|
Res: IntoResponse,
|
||||||
|
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
|
||||||
|
ResBody::Error: Into<BoxError> + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
HandleError::new(self, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check that your service cannot fail.
|
||||||
|
///
|
||||||
|
/// That is, its error type is [`Infallible`].
|
||||||
|
fn check_infallible<ReqBody>(self) -> Self
|
||||||
|
where
|
||||||
|
Self: Service<Request<ReqBody>, Error = Infallible>,
|
||||||
|
{
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, F> RoutingDsl for Route<S, F> {}
|
impl<S, F> RoutingDsl for Route<S, F> {}
|
||||||
|
@ -646,97 +734,6 @@ impl<S> RoutingDsl for Layered<S> {}
|
||||||
|
|
||||||
impl<S> crate::sealed::Sealed for Layered<S> {}
|
impl<S> crate::sealed::Sealed for Layered<S> {}
|
||||||
|
|
||||||
impl<S> Layered<S> {
|
|
||||||
/// Create a new [`Layered`] service where errors will be handled using the
|
|
||||||
/// given closure.
|
|
||||||
///
|
|
||||||
/// This is used to convert errors to responses rather than simply
|
|
||||||
/// terminating the connection.
|
|
||||||
///
|
|
||||||
/// That can be done using `handle_error` like so:
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// use axum::prelude::*;
|
|
||||||
/// use http::StatusCode;
|
|
||||||
/// use tower::{BoxError, timeout::TimeoutLayer};
|
|
||||||
/// use std::{convert::Infallible, time::Duration};
|
|
||||||
///
|
|
||||||
/// async fn handler() { /* ... */ }
|
|
||||||
///
|
|
||||||
/// // `Timeout` will fail with `BoxError` if the timeout elapses...
|
|
||||||
/// let layered_app = route("/", get(handler))
|
|
||||||
/// .layer(TimeoutLayer::new(Duration::from_secs(30)));
|
|
||||||
///
|
|
||||||
/// // ...so we should handle that error
|
|
||||||
/// let with_errors_handled = layered_app.handle_error(|error: BoxError| {
|
|
||||||
/// if error.is::<tower::timeout::error::Elapsed>() {
|
|
||||||
/// Ok::<_, Infallible>((
|
|
||||||
/// StatusCode::REQUEST_TIMEOUT,
|
|
||||||
/// "request took too long".to_string(),
|
|
||||||
/// ))
|
|
||||||
/// } else {
|
|
||||||
/// Ok::<_, Infallible>((
|
|
||||||
/// StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
/// format!("Unhandled internal error: {}", error),
|
|
||||||
/// ))
|
|
||||||
/// }
|
|
||||||
/// });
|
|
||||||
/// # async {
|
|
||||||
/// # axum::Server::bind(&"".parse().unwrap())
|
|
||||||
/// # .serve(with_errors_handled.into_make_service())
|
|
||||||
/// # .await
|
|
||||||
/// # .unwrap();
|
|
||||||
/// # };
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// The closure must return `Result<T, E>` where `T` implements [`IntoResponse`].
|
|
||||||
///
|
|
||||||
/// You can also return `Err(_)` if you don't wish to handle the error:
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// use axum::prelude::*;
|
|
||||||
/// use http::StatusCode;
|
|
||||||
/// use tower::{BoxError, timeout::TimeoutLayer};
|
|
||||||
/// use std::time::Duration;
|
|
||||||
///
|
|
||||||
/// async fn handler() { /* ... */ }
|
|
||||||
///
|
|
||||||
/// let layered_app = route("/", get(handler))
|
|
||||||
/// .layer(TimeoutLayer::new(Duration::from_secs(30)));
|
|
||||||
///
|
|
||||||
/// let with_errors_handled = layered_app.handle_error(|error: BoxError| {
|
|
||||||
/// if error.is::<tower::timeout::error::Elapsed>() {
|
|
||||||
/// Ok((
|
|
||||||
/// StatusCode::REQUEST_TIMEOUT,
|
|
||||||
/// "request took too long".to_string(),
|
|
||||||
/// ))
|
|
||||||
/// } else {
|
|
||||||
/// // keep the error as is
|
|
||||||
/// Err(error)
|
|
||||||
/// }
|
|
||||||
/// });
|
|
||||||
/// # async {
|
|
||||||
/// # axum::Server::bind(&"".parse().unwrap())
|
|
||||||
/// # .serve(with_errors_handled.into_make_service())
|
|
||||||
/// # .await
|
|
||||||
/// # .unwrap();
|
|
||||||
/// # };
|
|
||||||
/// ```
|
|
||||||
pub fn handle_error<F, ReqBody, ResBody, Res, E>(
|
|
||||||
self,
|
|
||||||
f: F,
|
|
||||||
) -> crate::service::HandleError<S, F, ReqBody, HandleErrorFromRouter>
|
|
||||||
where
|
|
||||||
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
|
|
||||||
F: FnOnce(S::Error) -> Result<Res, E>,
|
|
||||||
Res: IntoResponse,
|
|
||||||
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
|
|
||||||
ResBody::Error: Into<BoxError> + Send + Sync + 'static,
|
|
||||||
{
|
|
||||||
crate::service::HandleError::new(self.inner, f)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S, R> Service<R> for Layered<S>
|
impl<S, R> Service<R> for Layered<S>
|
||||||
where
|
where
|
||||||
S: Service<R>,
|
S: Service<R>,
|
||||||
|
@ -809,7 +806,7 @@ where
|
||||||
///
|
///
|
||||||
/// ```
|
/// ```
|
||||||
/// use axum::{
|
/// use axum::{
|
||||||
/// routing::nest, service::{get, ServiceExt}, prelude::*,
|
/// routing::nest, service::get, prelude::*,
|
||||||
/// };
|
/// };
|
||||||
/// use tower_http::services::ServeDir;
|
/// use tower_http::services::ServeDir;
|
||||||
///
|
///
|
||||||
|
|
|
@ -94,7 +94,6 @@ use crate::{
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use http::{Request, Response};
|
use http::{Request, Response};
|
||||||
use std::{
|
use std::{
|
||||||
convert::Infallible,
|
|
||||||
fmt,
|
fmt,
|
||||||
marker::PhantomData,
|
marker::PhantomData,
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
|
@ -428,6 +427,26 @@ impl<S, F> OnMethod<S, F> {
|
||||||
fallback: self,
|
fallback: self,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Handle errors this service might produce, by mapping them to responses.
|
||||||
|
///
|
||||||
|
/// Unhandled errors will close the connection without sending a response.
|
||||||
|
///
|
||||||
|
/// Works similarly to [`RoutingDsl::handle_error`]. See that for more
|
||||||
|
/// details.
|
||||||
|
///
|
||||||
|
/// [`RoutingDsl::handle_error`]: crate::routing::RoutingDsl::handle_error
|
||||||
|
pub fn handle_error<ReqBody, H, Res, E>(
|
||||||
|
self,
|
||||||
|
f: H,
|
||||||
|
) -> HandleError<Self, H, ReqBody, HandleErrorFromService>
|
||||||
|
where
|
||||||
|
Self: Service<Request<ReqBody>, Response = Response<BoxBody>>,
|
||||||
|
H: FnOnce(<Self as Service<Request<ReqBody>>>::Error) -> Result<Res, E>,
|
||||||
|
Res: IntoResponse,
|
||||||
|
{
|
||||||
|
HandleError::new(self, f)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// this is identical to `routing::OnMethod`'s implementation. Would be nice to find a way to clean
|
// this is identical to `routing::OnMethod`'s implementation. Would be nice to find a way to clean
|
||||||
|
@ -462,7 +481,7 @@ where
|
||||||
///
|
///
|
||||||
/// Created with
|
/// Created with
|
||||||
/// [`handler::Layered::handle_error`](crate::handler::Layered::handle_error) or
|
/// [`handler::Layered::handle_error`](crate::handler::Layered::handle_error) or
|
||||||
/// [`routing::Layered::handle_error`](crate::routing::Layered::handle_error).
|
/// [`routing::RoutingDsl::handle_error`](crate::routing::RoutingDsl::handle_error).
|
||||||
/// See those methods for more details.
|
/// See those methods for more details.
|
||||||
pub struct HandleError<S, F, B, T> {
|
pub struct HandleError<S, F, B, T> {
|
||||||
inner: S,
|
inner: S,
|
||||||
|
@ -542,75 +561,6 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extension trait that adds additional methods to [`Service`].
|
|
||||||
pub trait ServiceExt<ReqBody, ResBody>:
|
|
||||||
Service<Request<ReqBody>, Response = Response<ResBody>>
|
|
||||||
{
|
|
||||||
/// Handle errors from a service.
|
|
||||||
///
|
|
||||||
/// `handle_error` takes a closure that will map errors from the service
|
|
||||||
/// into responses. The closure's return type must be `Result<T, E>` where
|
|
||||||
/// `T` implements [`IntoResponse`](crate::response::IntoResponse).
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```rust,no_run
|
|
||||||
/// use axum::{service::{self, ServiceExt}, prelude::*};
|
|
||||||
/// use http::{Response, StatusCode};
|
|
||||||
/// use tower::{service_fn, BoxError};
|
|
||||||
/// use std::convert::Infallible;
|
|
||||||
///
|
|
||||||
/// // A service that might fail with `std::io::Error`
|
|
||||||
/// let service = service_fn(|_: Request<Body>| async {
|
|
||||||
/// let res = Response::new(Body::empty());
|
|
||||||
/// Ok::<_, std::io::Error>(res)
|
|
||||||
/// });
|
|
||||||
///
|
|
||||||
/// let app = route(
|
|
||||||
/// "/",
|
|
||||||
/// service.handle_error(|error: std::io::Error| {
|
|
||||||
/// Ok::<_, Infallible>((
|
|
||||||
/// StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
/// error.to_string(),
|
|
||||||
/// ))
|
|
||||||
/// }),
|
|
||||||
/// );
|
|
||||||
/// #
|
|
||||||
/// # async {
|
|
||||||
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
|
||||||
/// # };
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// It works similarly to [`routing::Layered::handle_error`]. See that for more details.
|
|
||||||
///
|
|
||||||
/// [`routing::Layered::handle_error`]: crate::routing::Layered::handle_error
|
|
||||||
fn handle_error<F, Res, E>(self, f: F) -> HandleError<Self, F, ReqBody, HandleErrorFromService>
|
|
||||||
where
|
|
||||||
Self: Sized,
|
|
||||||
F: FnOnce(Self::Error) -> Result<Res, E>,
|
|
||||||
Res: IntoResponse,
|
|
||||||
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
|
|
||||||
ResBody::Error: Into<BoxError> + Send + Sync + 'static,
|
|
||||||
{
|
|
||||||
HandleError::new(self, f)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check that your service cannot fail.
|
|
||||||
///
|
|
||||||
/// That is, its error type is [`Infallible`].
|
|
||||||
fn check_infallible(self) -> Self
|
|
||||||
where
|
|
||||||
Self: Service<Request<ReqBody>, Response = Response<ResBody>, Error = Infallible> + Sized,
|
|
||||||
{
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S, ReqBody, ResBody> ServiceExt<ReqBody, ResBody> for S where
|
|
||||||
S: Service<Request<ReqBody>, Response = Response<ResBody>>
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A [`Service`] that boxes response bodies.
|
/// A [`Service`] that boxes response bodies.
|
||||||
pub struct BoxResponseBody<S, B> {
|
pub struct BoxResponseBody<S, B> {
|
||||||
inner: S,
|
inner: S,
|
||||||
|
|
100
src/tests.rs
100
src/tests.rs
|
@ -16,9 +16,9 @@ use std::{
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
use tower::{make::Shared, service_fn, BoxError, Service, ServiceBuilder};
|
use tower::{make::Shared, service_fn, BoxError, Service};
|
||||||
use tower_http::{compression::CompressionLayer, trace::TraceLayer};
|
|
||||||
|
|
||||||
|
mod handle_error;
|
||||||
mod nest;
|
mod nest;
|
||||||
mod or;
|
mod or;
|
||||||
|
|
||||||
|
@ -326,53 +326,6 @@ async fn boxing() {
|
||||||
assert_eq!(res.text().await.unwrap(), "hi from POST");
|
assert_eq!(res.text().await.unwrap(), "hi from POST");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn service_handlers() {
|
|
||||||
use crate::service::ServiceExt as _;
|
|
||||||
use tower_http::services::ServeFile;
|
|
||||||
|
|
||||||
let app = route(
|
|
||||||
"/echo",
|
|
||||||
service::post(
|
|
||||||
service_fn(|req: Request<Body>| async move {
|
|
||||||
Ok::<_, BoxError>(Response::new(req.into_body()))
|
|
||||||
})
|
|
||||||
.handle_error(|_error: BoxError| Ok(StatusCode::INTERNAL_SERVER_ERROR)),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.route(
|
|
||||||
"/static/Cargo.toml",
|
|
||||||
service::on(
|
|
||||||
MethodFilter::GET,
|
|
||||||
ServeFile::new("Cargo.toml").handle_error(|error: std::io::Error| {
|
|
||||||
Ok::<_, Infallible>((StatusCode::INTERNAL_SERVER_ERROR, error.to_string()))
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
);
|
|
||||||
|
|
||||||
let addr = run_in_background(app).await;
|
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
|
|
||||||
let res = client
|
|
||||||
.post(format!("http://{}/echo", addr))
|
|
||||||
.body("foobar")
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(res.status(), StatusCode::OK);
|
|
||||||
assert_eq!(res.text().await.unwrap(), "foobar");
|
|
||||||
|
|
||||||
let res = client
|
|
||||||
.get(format!("http://{}/static/Cargo.toml", addr))
|
|
||||||
.body("foobar")
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(res.status(), StatusCode::OK);
|
|
||||||
assert!(res.text().await.unwrap().contains("edition ="));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn routing_between_services() {
|
async fn routing_between_services() {
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
|
@ -466,55 +419,6 @@ async fn middleware_on_single_route() {
|
||||||
assert_eq!(body, "Hello, World!");
|
assert_eq!(body, "Hello, World!");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn handling_errors_from_layered_single_routes() {
|
|
||||||
async fn handle(_req: Request<Body>) -> &'static str {
|
|
||||||
tokio::time::sleep(Duration::from_secs(10)).await;
|
|
||||||
""
|
|
||||||
}
|
|
||||||
|
|
||||||
let app = route(
|
|
||||||
"/",
|
|
||||||
get(handle
|
|
||||||
.layer(
|
|
||||||
ServiceBuilder::new()
|
|
||||||
.timeout(Duration::from_millis(100))
|
|
||||||
.layer(TraceLayer::new_for_http())
|
|
||||||
.into_inner(),
|
|
||||||
)
|
|
||||||
.handle_error(|_error: BoxError| {
|
|
||||||
Ok::<_, Infallible>(StatusCode::INTERNAL_SERVER_ERROR)
|
|
||||||
})),
|
|
||||||
);
|
|
||||||
|
|
||||||
let addr = run_in_background(app).await;
|
|
||||||
|
|
||||||
let res = reqwest::get(format!("http://{}", addr)).await.unwrap();
|
|
||||||
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn layer_on_whole_router() {
|
|
||||||
async fn handle(_req: Request<Body>) -> &'static str {
|
|
||||||
tokio::time::sleep(Duration::from_secs(10)).await;
|
|
||||||
""
|
|
||||||
}
|
|
||||||
|
|
||||||
let app = route("/", get(handle))
|
|
||||||
.layer(
|
|
||||||
ServiceBuilder::new()
|
|
||||||
.layer(CompressionLayer::new())
|
|
||||||
.timeout(Duration::from_millis(100))
|
|
||||||
.into_inner(),
|
|
||||||
)
|
|
||||||
.handle_error(|_err: BoxError| Ok::<_, Infallible>(StatusCode::INTERNAL_SERVER_ERROR));
|
|
||||||
|
|
||||||
let addr = run_in_background(app).await;
|
|
||||||
|
|
||||||
let res = reqwest::get(format!("http://{}", addr)).await.unwrap();
|
|
||||||
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[cfg(feature = "header")]
|
#[cfg(feature = "header")]
|
||||||
async fn typed_header() {
|
async fn typed_header() {
|
||||||
|
|
203
src/tests/handle_error.rs
Normal file
203
src/tests/handle_error.rs
Normal file
|
@ -0,0 +1,203 @@
|
||||||
|
use super::*;
|
||||||
|
use futures_util::future::{pending, ready};
|
||||||
|
use tower::{timeout::TimeoutLayer, MakeService};
|
||||||
|
|
||||||
|
async fn unit() {}
|
||||||
|
|
||||||
|
async fn forever() {
|
||||||
|
pending().await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn timeout() -> TimeoutLayer {
|
||||||
|
TimeoutLayer::new(Duration::from_millis(10))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct Svc;
|
||||||
|
|
||||||
|
impl<R> Service<R> for Svc {
|
||||||
|
type Response = Response<Body>;
|
||||||
|
type Error = hyper::Error;
|
||||||
|
type Future = Ready<Result<Self::Response, Self::Error>>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, _req: R) -> Self::Future {
|
||||||
|
ready(Ok(Response::new(Body::empty())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_make_svc<M, R, T, E>(_make_svc: M)
|
||||||
|
where
|
||||||
|
M: MakeService<(), R, Response = T, Error = E>,
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_error<E>(_: E) -> Result<StatusCode, Infallible> {
|
||||||
|
Ok(StatusCode::INTERNAL_SERVER_ERROR)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn handler() {
|
||||||
|
let app = route(
|
||||||
|
"/",
|
||||||
|
get(forever
|
||||||
|
.layer(timeout())
|
||||||
|
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT))),
|
||||||
|
);
|
||||||
|
|
||||||
|
let addr = run_in_background(app).await;
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
let res = client
|
||||||
|
.get(format!("http://{}/", addr))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn handler_multiple_methods_first() {
|
||||||
|
let app = route(
|
||||||
|
"/",
|
||||||
|
get(forever
|
||||||
|
.layer(timeout())
|
||||||
|
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT)))
|
||||||
|
.post(unit),
|
||||||
|
);
|
||||||
|
|
||||||
|
let addr = run_in_background(app).await;
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
let res = client
|
||||||
|
.get(format!("http://{}/", addr))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn handler_multiple_methods_middle() {
|
||||||
|
let app = route(
|
||||||
|
"/",
|
||||||
|
delete(unit)
|
||||||
|
.get(
|
||||||
|
forever
|
||||||
|
.layer(timeout())
|
||||||
|
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT)),
|
||||||
|
)
|
||||||
|
.post(unit),
|
||||||
|
);
|
||||||
|
|
||||||
|
let addr = run_in_background(app).await;
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
let res = client
|
||||||
|
.get(format!("http://{}/", addr))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn handler_multiple_methods_last() {
|
||||||
|
let app = route(
|
||||||
|
"/",
|
||||||
|
delete(unit).get(
|
||||||
|
forever
|
||||||
|
.layer(timeout())
|
||||||
|
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT)),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
let addr = run_in_background(app).await;
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
let res = client
|
||||||
|
.get(format!("http://{}/", addr))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn service_propagates_errors() {
|
||||||
|
let app = route::<_, Body>("/echo", service::post(Svc));
|
||||||
|
|
||||||
|
check_make_svc::<_, _, _, hyper::Error>(app.into_make_service());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn service_nested_propagates_errors() {
|
||||||
|
let app = route::<_, Body>("/echo", nest("/foo", service::post(Svc)));
|
||||||
|
|
||||||
|
check_make_svc::<_, _, _, hyper::Error>(app.into_make_service());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn service_handle_on_method() {
|
||||||
|
let app = route::<_, Body>(
|
||||||
|
"/echo",
|
||||||
|
service::get(Svc).handle_error(handle_error::<hyper::Error>),
|
||||||
|
);
|
||||||
|
|
||||||
|
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn service_handle_on_method_multiple() {
|
||||||
|
let app = route::<_, Body>(
|
||||||
|
"/echo",
|
||||||
|
service::get(Svc)
|
||||||
|
.post(Svc)
|
||||||
|
.handle_error(handle_error::<hyper::Error>),
|
||||||
|
);
|
||||||
|
|
||||||
|
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn service_handle_on_router() {
|
||||||
|
let app =
|
||||||
|
route::<_, Body>("/echo", service::get(Svc)).handle_error(handle_error::<hyper::Error>);
|
||||||
|
|
||||||
|
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn service_handle_on_router_still_impls_routing_dsl() {
|
||||||
|
let app = route::<_, Body>("/echo", service::get(Svc))
|
||||||
|
.handle_error(handle_error::<hyper::Error>)
|
||||||
|
.route("/", get(unit));
|
||||||
|
|
||||||
|
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn layered() {
|
||||||
|
let app = route::<_, Body>("/echo", get(unit))
|
||||||
|
.layer(timeout())
|
||||||
|
.handle_error(handle_error::<BoxError>);
|
||||||
|
|
||||||
|
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test] // async because of `.boxed()`
|
||||||
|
async fn layered_boxed() {
|
||||||
|
let app = route::<_, Body>("/echo", get(unit))
|
||||||
|
.layer(timeout())
|
||||||
|
.boxed()
|
||||||
|
.handle_error(handle_error::<BoxError>);
|
||||||
|
|
||||||
|
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
|
||||||
|
}
|
Loading…
Reference in a new issue