diff --git a/Cargo.toml b/Cargo.toml index 0013785b..beaeb0b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ thiserror = "1.0" tower = { version = "0.4", features = ["util", "buffer"] } tower-http = { version = "0.1", features = ["add-extension"] } regex = "1.5" +tokio = { version = "1", features = ["time"] } [dev-dependencies] hyper = { version = "0.14", features = ["full"] } @@ -39,4 +40,6 @@ features = [ "compression-full", "fs", "trace", + "redirect", + "auth", ] diff --git a/examples/key_value_store.rs b/examples/key_value_store.rs index 2224028c..4460854b 100644 --- a/examples/key_value_store.rs +++ b/examples/key_value_store.rs @@ -1,47 +1,70 @@ +//! Simple in-memory key/value store showing features of tower-web. +//! +//! Run with: +//! +//! ```not_rust +//! RUST_LOG=tower_http=debug,key_value_store=trace cargo run --example key_value_store +//! ``` + use bytes::Bytes; use http::{Request, StatusCode}; use hyper::Server; use std::{ + borrow::Cow, collections::HashMap, net::SocketAddr, - sync::{Arc, Mutex}, + sync::{Arc, RwLock}, time::Duration, }; -use tower::{make::Shared, ServiceBuilder}; +use tower::{make::Shared, BoxError, ServiceBuilder}; use tower_http::{ - add_extension::AddExtensionLayer, compression::CompressionLayer, trace::TraceLayer, + add_extension::AddExtensionLayer, auth::RequireAuthorizationLayer, + compression::CompressionLayer, trace::TraceLayer, }; use tower_web::{ - body::Body, + body::{Body, BoxBody}, extract::{BytesMaxLength, Extension, UrlParams}, prelude::*, + response::IntoResponse, + routing::BoxRoute, }; #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); - // build our application with some routes + // Build our application by composing routes let app = route( "/:key", - get(kv_get.layer(CompressionLayer::new())).post(kv_set), - ); + // Add compression to `kv_get` + get(kv_get.layer(CompressionLayer::new())) + // But don't compress `kv_set` + .post(kv_set), + ) + .route("/keys", get(list_keys)) + // Nest our admin routes under `/admin` + .nest("/admin", admin_routes()) + // Add middleware to all routes + .layer( + ServiceBuilder::new() + .load_shed() + .concurrency_limit(1024) + .timeout(Duration::from_secs(10)) + .layer(TraceLayer::new_for_http()) + .layer(AddExtensionLayer::new(SharedState::default())) + .into_inner(), + ) + // Handle errors from middleware + .handle_error(handle_error); - // add some middleware - let app = ServiceBuilder::new() - .timeout(Duration::from_secs(10)) - .layer(TraceLayer::new_for_http()) - .layer(AddExtensionLayer::new(SharedState::default())) - .service(app); - - // run it with hyper + // Run our app with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); let server = Server::bind(&addr).serve(Shared::new(app)); server.await.unwrap(); } -type SharedState = Arc>; +type SharedState = Arc>; #[derive(Default)] struct State { @@ -53,7 +76,7 @@ async fn kv_get( UrlParams((key,)): UrlParams<(String,)>, Extension(state): Extension, ) -> Result { - let db = &state.lock().unwrap().db; + let db = &state.read().unwrap().db; if let Some(value) = db.get(&key) { Ok(value.clone()) @@ -68,5 +91,52 @@ async fn kv_set( BytesMaxLength(value): BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb Extension(state): Extension, ) { - state.lock().unwrap().db.insert(key, value); + state.write().unwrap().db.insert(key, value); +} + +async fn list_keys(_req: Request, Extension(state): Extension) -> String { + let db = &state.read().unwrap().db; + + db.keys() + .map(|key| key.to_string()) + .collect::>() + .join("\n") +} + +fn admin_routes() -> BoxRoute { + async fn delete_all_keys(_req: Request, Extension(state): Extension) { + state.write().unwrap().db.clear(); + } + + async fn remove_key( + _req: Request, + UrlParams((key,)): UrlParams<(String,)>, + Extension(state): Extension, + ) { + state.write().unwrap().db.remove(&key); + } + + route("/keys", delete(delete_all_keys)) + .route("/key/:key", delete(remove_key)) + // Require beare auth for all admin routes + .layer(RequireAuthorizationLayer::bearer("secret-token")) + .boxed() +} + +fn handle_error(error: BoxError) -> impl IntoResponse { + if error.is::() { + return (StatusCode::REQUEST_TIMEOUT, Cow::from("request timed out")); + } + + if error.is::() { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Cow::from("service is overloaded, try again later"), + ); + } + + ( + StatusCode::INTERNAL_SERVER_ERROR, + Cow::from(format!("Unhandled internal error: {}", error)), + ) } diff --git a/examples/static_file_server.rs b/examples/static_file_server.rs index 07f165e9..d16da746 100644 --- a/examples/static_file_server.rs +++ b/examples/static_file_server.rs @@ -3,7 +3,7 @@ use hyper::Server; use std::net::SocketAddr; use tower::make::Shared; use tower_http::{services::ServeDir, trace::TraceLayer}; -use tower_web::{prelude::*, ServiceExt}; +use tower_web::{prelude::*, service::ServiceExt}; #[tokio::main] async fn main() { diff --git a/src/extract/mod.rs b/src/extract/mod.rs index 175ba36f..b90f6058 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -44,8 +44,8 @@ where type Rejection = QueryStringMissing; async fn from_request(req: &mut Request) -> Result { - let query = req.uri().query().ok_or(QueryStringMissing(()))?; - let value = serde_urlencoded::from_str(query).map_err(|_| QueryStringMissing(()))?; + let query = req.uri().query().ok_or(QueryStringMissing)?; + let value = serde_urlencoded::from_str(query).map_err(|_| QueryStringMissing)?; Ok(Query(value)) } } @@ -75,7 +75,7 @@ where Ok(Json(value)) } else { - Err(MissingJsonContentType(()).into_response()) + Err(MissingJsonContentType.into_response()) } } } @@ -110,7 +110,7 @@ where let value = req .extensions() .get::() - .ok_or(MissingExtension(())) + .ok_or(MissingExtension) .map(|x| x.clone())?; Ok(Extension(value)) @@ -179,10 +179,10 @@ impl FromRequest for BytesMaxLength { if let Some(length) = content_length { if length > N { - return Err(PayloadTooLarge(()).into_response()); + return Err(PayloadTooLarge.into_response()); } } else { - return Err(LengthRequired(()).into_response()); + return Err(LengthRequired.into_response()); }; let bytes = hyper::body::to_bytes(body) @@ -221,7 +221,7 @@ impl FromRequest for UrlParamsMap { let params = params.take().expect("params already taken").0; Ok(Self(params.into_iter().collect())) } else { - Err(MissingRouteParams(())) + Err(MissingRouteParams) } } } @@ -248,7 +248,7 @@ macro_rules! impl_parse_url { { params.take().expect("params already taken").0 } else { - return Err(MissingRouteParams(()).into_response()) + return Err(MissingRouteParams.into_response()) }; if let [(_, $head), $((_, $tail),)*] = &*params { @@ -268,7 +268,7 @@ macro_rules! impl_parse_url { Ok(UrlParams(($head, $($tail,)*))) } else { - return Err(MissingRouteParams(()).into_response()) + return Err(MissingRouteParams.into_response()) } } } @@ -283,7 +283,7 @@ fn take_body(req: &mut Request) -> Result { struct BodyAlreadyTakenExt; if req.extensions_mut().insert(BodyAlreadyTakenExt).is_some() { - Err(BodyAlreadyTaken(())) + Err(BodyAlreadyTaken) } else { let body = std::mem::take(req.body_mut()); Ok(body) diff --git a/src/extract/rejection.rs b/src/extract/rejection.rs index 49c48dd4..08aeaac9 100644 --- a/src/extract/rejection.rs +++ b/src/extract/rejection.rs @@ -8,11 +8,12 @@ macro_rules! define_rejection { #[status = $status:ident] #[body = $body:expr] $(#[$m:meta])* - pub struct $name:ident (()); + pub struct $name:ident; ) => { $(#[$m])* #[derive(Debug)] - pub struct $name(pub(super) ()); + #[non_exhaustive] + pub struct $name; impl IntoResponse for $name { fn into_response(self) -> http::Response { @@ -57,7 +58,7 @@ define_rejection! { #[status = BAD_REQUEST] #[body = "Query string was invalid or missing"] /// Rejection type for [`Query`](super::Query). - pub struct QueryStringMissing(()); + pub struct QueryStringMissing; } define_rejection! { @@ -72,7 +73,7 @@ define_rejection! { #[body = "Expected request with `Content-Type: application/json`"] /// Rejection type for [`Json`](super::Json) used if the `Content-Type` /// header is missing. - pub struct MissingJsonContentType(()); + pub struct MissingJsonContentType; } define_rejection! { @@ -80,7 +81,7 @@ define_rejection! { #[body = "Missing request extension"] /// Rejection type for [`Extension`](super::Extension) if an expected /// request extension was not found. - pub struct MissingExtension(()); + pub struct MissingExtension; } define_rejection! { @@ -104,7 +105,7 @@ define_rejection! { #[body = "Request payload is too large"] /// Rejection type for [`BytesMaxLength`](super::BytesMaxLength) if the /// request body is too large. - pub struct PayloadTooLarge(()); + pub struct PayloadTooLarge; } define_rejection! { @@ -112,7 +113,7 @@ define_rejection! { #[body = "Content length header is required"] /// Rejection type for [`BytesMaxLength`](super::BytesMaxLength) if the /// request is missing the `Content-Length` header or it is invalid. - pub struct LengthRequired(()); + pub struct LengthRequired; } define_rejection! { @@ -121,7 +122,7 @@ define_rejection! { /// Rejection type for [`UrlParamsMap`](super::UrlParamsMap) and /// [`UrlParams`](super::UrlParams) if you try and extract the URL params /// more than once. - pub struct MissingRouteParams(()); + pub struct MissingRouteParams; } define_rejection! { @@ -129,7 +130,7 @@ define_rejection! { #[body = "Cannot have two request body extractors for a single handler"] /// Rejection type used if you try and extract the request body more than /// once. - pub struct BodyAlreadyTaken(()); + pub struct BodyAlreadyTaken; } /// Rejection type for [`UrlParams`](super::UrlParams) if the capture route diff --git a/src/handler/future.rs b/src/handler/future.rs new file mode 100644 index 00000000..0556c539 --- /dev/null +++ b/src/handler/future.rs @@ -0,0 +1,11 @@ +//! Handler future types. + +use http::Response; +use std::convert::Infallible; +use crate::body::BoxBody; + +opaque_future! { + /// The response future for [`IntoService`](super::IntoService). + pub type IntoServiceFuture = + futures_util::future::BoxFuture<'static, Result, Infallible>>; +} diff --git a/src/handler.rs b/src/handler/mod.rs similarity index 88% rename from src/handler.rs rename to src/handler/mod.rs index ab2b8c16..9be1468a 100644 --- a/src/handler.rs +++ b/src/handler/mod.rs @@ -1,4 +1,46 @@ //! Async functions that can be used to handle requests. +//! +//! # What is a handler? +//! +//! In tower-web a "handler" is an async function that accepts a request and +//! produces a response. Handler functions must take +//! `http::Request` as they first argument and return +//! something that implements [`IntoResponse`]. +//! +//! Additionally handlers can use ["extractors"](crate::extract) to extract data +//! from incoming requests. +//! +//! # Example +//! +//! Some examples of handlers: +//! +//! ```rust +//! use tower_web::prelude::*; +//! use bytes::Bytes; +//! use http::StatusCode; +//! +//! // Handlers must take `Request` as the first argument and must return +//! // something that implements `IntoResponse`, which `()` does +//! async fn unit_handler(request: Request) {} +//! +//! // `String` also implements `IntoResponse` +//! async fn string_handler(request: Request) -> String { +//! "Hello, World!".to_string() +//! } +//! +//! // Handler that buffers the request body and returns it if it is valid UTF-8 +//! async fn buffer_body(request: Request, body: Bytes) -> Result { +//! if let Ok(string) = String::from_utf8(body.to_vec()) { +//! Ok(string) +//! } else { +//! Err(StatusCode::BAD_REQUEST) +//! } +//! } +//! ``` +//! +//! For more details on generating responses see the +//! [`response`](crate::response) module and for more details on extractors see +//! the [`extract`](crate::extract) module. use crate::{ body::{Body, BoxBody}, @@ -8,8 +50,8 @@ use crate::{ service::HandleError, }; use async_trait::async_trait; +use futures_util::future::Either; use bytes::Bytes; -use futures_util::future; use http::{Request, Response}; use std::{ convert::Infallible, @@ -20,6 +62,8 @@ use std::{ }; use tower::{BoxError, Layer, Service, ServiceExt}; +pub mod future; + /// Route requests to the given handler regardless of the HTTP method of the /// request. /// @@ -175,37 +219,7 @@ mod sealed { /// You shouldn't need to depend on this trait directly. It is automatically /// implemented to closures of the right types. /// -/// # Example -/// -/// Some examples of handlers: -/// -/// ```rust -/// use tower_web::prelude::*; -/// use bytes::Bytes; -/// use http::StatusCode; -/// -/// // Handlers must take `Request` as the first argument and must return -/// // something that implements `IntoResponse`, which `()` does -/// async fn unit_handler(request: Request) {} -/// -/// // `String` also implements `IntoResponse` -/// async fn string_handler(request: Request) -> String { -/// "Hello, World!".to_string() -/// } -/// -/// // Handler the buffers the request body and returns it if it is valid UTF-8 -/// async fn buffer_body(request: Request, body: Bytes) -> Result { -/// if let Ok(string) = String::from_utf8(body.to_vec()) { -/// Ok(string) -/// } else { -/// Err(StatusCode::BAD_REQUEST) -/// } -/// } -/// ``` -/// -/// For more details on generating responses see the -/// [`response`](crate::response) module and for more details on extractors see -/// the [`extract`](crate::extract) module. +/// See the [module docs](crate::handler) for more details. #[async_trait] pub trait Handler: Sized { // This seals the trait. We cannot use the regular "sealed super trait" approach @@ -218,10 +232,19 @@ pub trait Handler: Sized { /// Apply a [`tower::Layer`] to the handler. /// + /// All requests to the handler will be processed by the layer's + /// corresponding middleware. + /// + /// This can be used to add additional processing to a request for a single + /// handler. + /// + /// Note this differes from [`routing::Layered`](crate::routing::Layered) + /// which adds a middleware to a group of routes. + /// /// # Example /// /// Adding the [`tower::limit::ConcurrencyLimit`] middleware to a handler - /// can be done with [`tower::limit::ConcurrencyLimitLayer`]: + /// can be done like so: /// /// ```rust /// use tower_web::prelude::*; @@ -304,7 +327,7 @@ impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, /// A [`Service`] created from a [`Handler`] by applying a Tower middleware. /// -/// Created with [`Handler::layer`]. +/// Created with [`Handler::layer`]. See that method for more details. pub struct Layered { svc: S, _input: PhantomData T>, @@ -332,7 +355,6 @@ where impl Handler for Layered where S: Service, Response = Response> + Send, - // S::Response: IntoResponse, S::Error: IntoResponse, S::Future: Send, B: http_body::Body + Send + Sync + 'static, @@ -456,7 +478,7 @@ where { type Response = Response; type Error = Infallible; - type Future = IntoServiceFuture; + type Future = future::IntoServiceFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { // `IntoService` can only be constructed from async functions which are always ready, or from @@ -471,17 +493,12 @@ where let res = Handler::call(handler, req).await; Ok(res) }); - IntoServiceFuture(future) + future::IntoServiceFuture(future) } } -opaque_future! { - /// The response future for [`IntoService`]. - pub type IntoServiceFuture = - future::BoxFuture<'static, Result, Infallible>>; -} - -/// A handler [`Service`] that accepts requests based on a [`MethodFilter`]. +/// A handler [`Service`] that accepts requests based on a [`MethodFilter`] and +/// allows chaining additional handlers. #[derive(Debug, Clone, Copy)] pub struct OnMethod { pub(crate) method: MethodFilter, @@ -652,10 +669,10 @@ where fn call(&mut self, req: Request) -> Self::Future { let f = if self.method.matches(req.method()) { let response_future = self.svc.clone().oneshot(req); - future::Either::Left(BoxResponseBody(response_future)) + Either::Left(BoxResponseBody(response_future)) } else { let response_future = self.fallback.clone().oneshot(req); - future::Either::Right(BoxResponseBody(response_future)) + Either::Right(BoxResponseBody(response_future)) }; RouteFuture(f) } diff --git a/src/lib.rs b/src/lib.rs index d7cd93bb..bab2aeba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,7 @@ //! - Solid foundation. tower-web is built on top of tower and makes it easy to //! plug in any middleware from the [tower] and [tower-http] ecosystem. //! - Focus on routing, extracting data from requests, and generating responses. -//! tower middleware can handle the rest. +//! Tower middleware can handle the rest. //! - Macro free core. Macro frameworks have their place but tower-web focuses //! on providing a core that is macro free. //! @@ -67,7 +67,8 @@ //! //! # Responses //! -//! Anything that implements [`IntoResponse`] can be returned from a handler: +//! Anything that implements [`IntoResponse`](response::IntoResponse) can be +//! returned from a handler: //! //! ```rust,no_run //! use tower_web::{body::Body, response::{Html, Json}, prelude::*}; @@ -288,8 +289,8 @@ //! implementations. //! //! For handlers created from async functions this is works automatically since -//! handlers must return something that implements [`IntoResponse`], even if its -//! a `Result`. +//! handlers must return something that implements +//! [`IntoResponse`](response::IntoResponse), even if its a `Result`. //! //! However middleware might add new failure cases that has to be handled. For //! that tower-web provides a `handle_error` combinator: @@ -447,9 +448,8 @@ //! //! ```rust,no_run //! use tower_web::{ -//! service, prelude::*, //! // `ServiceExt` adds `handle_error` to any `Service` -//! ServiceExt, +//! service::{self, ServiceExt}, prelude::*, //! }; //! use tower_http::services::ServeFile; //! use http::Response; @@ -503,7 +503,7 @@ //! `nest` can also be used to serve static files from a directory: //! //! ```rust,no_run -//! use tower_web::{prelude::*, ServiceExt, routing::nest}; +//! use tower_web::{prelude::*, service::ServiceExt, routing::nest}; //! use tower_http::services::ServeDir; //! use http::Response; //! use std::convert::Infallible; @@ -571,12 +571,10 @@ #![cfg_attr(test, allow(clippy::float_cmp))] use self::body::Body; -use bytes::Bytes; -use http::{Request, Response}; -use response::IntoResponse; +use http::Request; use routing::{EmptyRouter, Route}; use std::convert::Infallible; -use tower::{BoxError, Service}; +use tower::Service; #[macro_use] pub(crate) mod macros; @@ -624,8 +622,8 @@ pub mod prelude { /// Note that `service`'s error type must be [`Infallible`] meaning you must /// handle all errors. If you're creating handlers from async functions that is /// handled automatically but if you're routing to some other [`Service`] you -/// might need to use [`handle_error`](ServiceExt::handle_error) to map errors -/// into responses. +/// might need to use [`handle_error`](service::ServiceExt::handle_error) to map +/// errors into responses. /// /// # Examples /// @@ -655,59 +653,6 @@ where routing::EmptyRouter.route(description, service) } -/// Extension trait that adds additional methods to [`Service`]. -pub trait ServiceExt: Service, Response = Response> { - /// Handle errors from a service. - /// - /// tower-web requires all handles to never return errors. If you route to - /// [`Service`], not created by tower-web, who's error isn't `Infallible` - /// you can use this combinator to handle the error. - /// - /// `handle_error` takes a closure that will map errors from the service - /// into responses. The closure's return type must implement - /// [`IntoResponse`]. - /// - /// # Example - /// - /// ```rust,no_run - /// use tower_web::{ - /// service, prelude::*, - /// ServiceExt, - /// }; - /// use http::Response; - /// use tower::{service_fn, BoxError}; - /// - /// // A service that might fail with `std::io::Error` - /// let service = service_fn(|_: Request| async { - /// let res = Response::new(Body::empty()); - /// Ok::<_, std::io::Error>(res) - /// }); - /// - /// let app = route( - /// "/", - /// service.handle_error(|error: std::io::Error| { - /// // Handle error by returning something that implements `IntoResponse` - /// }), - /// ); - /// # - /// # async { - /// # hyper::Server::bind(&"".parse().unwrap()).serve(tower::make::Shared::new(app)).await; - /// # }; - /// ``` - fn handle_error(self, f: F) -> service::HandleError - where - Self: Sized, - F: FnOnce(Self::Error) -> Res, - Res: IntoResponse, - B: http_body::Body + Send + Sync + 'static, - B::Error: Into + Send + Sync + 'static, - { - service::HandleError::new(self, f) - } -} - -impl ServiceExt for S where S: Service, Response = Response> {} - pub(crate) trait ResultExt { fn unwrap_infallible(self) -> T; } @@ -720,3 +665,9 @@ impl ResultExt for Result { } } } + +mod sealed { + #![allow(unreachable_pub, missing_docs)] + + pub trait Sealed {} +} diff --git a/src/routing.rs b/src/routing.rs index bccfe978..7927b6bb 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -12,6 +12,7 @@ use regex::Regex; use std::{ borrow::Cow, convert::Infallible, + fmt, future::Future, pin::Pin, sync::Arc, @@ -78,7 +79,27 @@ pub struct Route { pub(crate) fallback: F, } -pub trait RoutingDsl: Sized { +/// Trait for building routers. +// TODO(david): this name isn't great +pub trait RoutingDsl: crate::sealed::Sealed + Sized { + /// Add another route to the router. + /// + /// # Example + /// + /// ```rust + /// use tower_web::prelude::*; + /// + /// async fn first_handler(request: Request) { /* ... */ } + /// + /// async fn second_handler(request: Request) { /* ... */ } + /// + /// async fn third_handler(request: Request) { /* ... */ } + /// + /// // `GET /` goes to `first_handler`, `POST /` goes to `second_handler`, + /// // and `GET /foo` goes to third_handler. + /// let app = route("/", get(first_handler).post(second_handler)) + /// .route("/foo", get(third_handler)); + /// ``` fn route(self, description: &str, svc: T) -> Route where T: Service, Error = Infallible> + Clone, @@ -90,6 +111,9 @@ pub trait RoutingDsl: Sized { } } + /// Nest another service inside this router at the given path. + /// + /// See [`nest`] for more details. fn nest(self, description: &str, svc: T) -> Nested where T: Service, Error = Infallible> + Clone, @@ -101,6 +125,29 @@ pub trait RoutingDsl: Sized { } } + /// Create a boxed route trait object. + /// + /// This makes it easier to name the types of routers to, for example, + /// return them from functions: + /// + /// ```rust + /// use tower_web::{body::BoxBody, routing::BoxRoute, prelude::*}; + /// + /// async fn first_handler(request: Request) { /* ... */ } + /// + /// async fn second_handler(request: Request) { /* ... */ } + /// + /// async fn third_handler(request: Request) { /* ... */ } + /// + /// fn app() -> BoxRoute { + /// route("/", get(first_handler).post(second_handler)) + /// .route("/foo", get(third_handler)) + /// .boxed() + /// } + /// ``` + /// + /// It also helps with compile times when you have a very large number of + /// routes. fn boxed(self) -> BoxRoute where Self: Service, Response = Response, Error = Infallible> + Send + 'static, @@ -115,6 +162,66 @@ pub trait RoutingDsl: Sized { .service(self) } + /// Apply a [`tower::Layer`] to the router. + /// + /// All requests to the router will be processed by the layer's + /// corresponding middleware. + /// + /// This can be used to add additional processing to a request for a group + /// of routes. + /// + /// Note this differes from [`handler::Layered`](crate::handler::Layered) + /// which adds a middleware to a single handler. + /// + /// # Example + /// + /// Adding the [`tower::limit::ConcurrencyLimit`] middleware to a group of + /// routes can be done like so: + /// + /// ```rust + /// use tower_web::prelude::*; + /// use tower::limit::{ConcurrencyLimitLayer, ConcurrencyLimit}; + /// + /// async fn first_handler(request: Request) { /* ... */ } + /// + /// async fn second_handler(request: Request) { /* ... */ } + /// + /// async fn third_handler(request: Request) { /* ... */ } + /// + /// // All requests to `handler` and `other_handler` will be sent through + /// // `ConcurrencyLimit` + /// let app = route("/", get(first_handler)) + /// .route("/foo", get(second_handler)) + /// .layer(ConcurrencyLimitLayer::new(64)) + /// // Request to `GET /bar` will go directly to `third_handler` and + /// // wont be sent through `ConcurrencyLimit` + /// .route("/bar", get(third_handler)); + /// # async { + /// # hyper::Server::bind(&"".parse().unwrap()).serve(tower::make::Shared::new(app)).await; + /// # }; + /// ``` + /// + /// This is commonly used to add middleware such as tracing/logging to your + /// entire app: + /// + /// ```rust + /// use tower_web::prelude::*; + /// use tower_http::trace::TraceLayer; + /// + /// async fn first_handler(request: Request) { /* ... */ } + /// + /// async fn second_handler(request: Request) { /* ... */ } + /// + /// async fn third_handler(request: Request) { /* ... */ } + /// + /// let app = route("/", get(first_handler)) + /// .route("/foo", get(second_handler)) + /// .route("/bar", get(third_handler)) + /// .layer(TraceLayer::new_for_http()); + /// ``` + /// + /// When adding middleware that might fail its required to handle those + /// errors. See [`Layered::handle_error`] for more details. fn layer(self, layer: L) -> Layered where L: Layer, @@ -126,6 +233,8 @@ pub trait RoutingDsl: Sized { impl RoutingDsl for Route {} +impl crate::sealed::Sealed for Route {} + impl Service> for Route where S: Service, Response = Response, Error = Infallible> + Clone, @@ -232,6 +341,8 @@ pub struct EmptyRouter; impl RoutingDsl for EmptyRouter {} +impl crate::sealed::Sealed for EmptyRouter {} + impl Service> for EmptyRouter { type Response = Response; type Error = Infallible; @@ -344,8 +455,17 @@ struct Match<'a> { type Captures = Vec<(String, String)>; +/// A boxed route trait object. +/// +/// See [`RoutingDsl::boxed`] for more details. pub struct BoxRoute(Buffer, Response, Infallible>, Request>); +impl fmt::Debug for BoxRoute { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BoxRoute").finish() + } +} + impl Clone for BoxRoute { fn clone(&self) -> Self { Self(self.0.clone()) @@ -354,6 +474,8 @@ impl Clone for BoxRoute { impl RoutingDsl for BoxRoute {} +impl crate::sealed::Sealed for BoxRoute {} + impl Service> for BoxRoute where B: http_body::Body + Send + Sync + 'static, @@ -383,6 +505,12 @@ type InnerFuture = Oneshot< Request, >; +impl fmt::Debug for BoxRouteFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BoxRouteFuture").finish() + } +} + impl Future for BoxRouteFuture where B: http_body::Body + Send + Sync + 'static, @@ -430,12 +558,55 @@ fn handle_buffer_error(error: BoxError) -> Response { .unwrap() } +/// A [`Service`] created from a router by applying a Tower middleware. +/// +/// Created with [`RoutingDsl::layer`]. See that method for more details. #[derive(Clone, Debug)] pub struct Layered(S); impl RoutingDsl for Layered {} +impl crate::sealed::Sealed for Layered {} + impl Layered { + /// Create a new [`Layered`] service where errors will be handled using the + /// given closure. + /// + /// tower-web requires that services gracefully handles all errors. That + /// means when you apply a Tower middleware that adds a new failure + /// condition you have to handle that as well. + /// + /// That can be done using `handle_error` like so: + /// + /// ```rust + /// use tower_web::prelude::*; + /// use http::StatusCode; + /// use tower::{BoxError, timeout::TimeoutLayer}; + /// use std::time::Duration; + /// + /// async fn handler(request: Request) { /* ... */ } + /// + /// // `Timeout` will fail with `BoxError` if the timeout elapses... + /// let layered_handler = route("/", get(handler)) + /// .layer(TimeoutLayer::new(Duration::from_secs(30))); + /// + /// // ...so we must handle that error + /// let layered_handler = layered_handler.handle_error(|error: BoxError| { + /// if error.is::() { + /// ( + /// StatusCode::REQUEST_TIMEOUT, + /// "request took too long".to_string(), + /// ) + /// } else { + /// ( + /// StatusCode::INTERNAL_SERVER_ERROR, + /// format!("Unhandled internal error: {}", error), + /// ) + /// } + /// }); + /// ``` + /// + /// The closure can return any type that implements [`IntoResponse`]. pub fn handle_error(self, f: F) -> crate::service::HandleError where S: Service, Response = Response> + Clone, @@ -520,7 +691,7 @@ where /// /// ``` /// use tower_web::{ -/// routing::nest, service::get, ServiceExt, prelude::*, +/// routing::nest, service::{get, ServiceExt}, prelude::*, /// }; /// use tower_http::services::ServeDir; /// @@ -560,6 +731,8 @@ pub struct Nested { impl RoutingDsl for Nested {} +impl crate::sealed::Sealed for Nested {} + impl Service> for Nested where S: Service, Response = Response, Error = Infallible> + Clone, diff --git a/src/service.rs b/src/service.rs deleted file mode 100644 index 6b41e0a1..00000000 --- a/src/service.rs +++ /dev/null @@ -1,274 +0,0 @@ -use crate::{ - body::{Body, BoxBody}, - response::IntoResponse, - routing::{BoxResponseBody, EmptyRouter, MethodFilter, RouteFuture}, -}; -use bytes::Bytes; -use futures_util::future; -use futures_util::ready; -use http::{Request, Response}; -use pin_project::pin_project; -use std::{ - convert::Infallible, - fmt, - future::Future, - pin::Pin, - task::{Context, Poll}, -}; -use tower::{util::Oneshot, BoxError, Service, ServiceExt as _}; - -pub fn any(svc: S) -> OnMethod { - on(MethodFilter::Any, svc) -} - -pub fn connect(svc: S) -> OnMethod { - on(MethodFilter::Connect, svc) -} - -pub fn delete(svc: S) -> OnMethod { - on(MethodFilter::Delete, svc) -} - -pub fn get(svc: S) -> OnMethod { - on(MethodFilter::Get, svc) -} - -pub fn head(svc: S) -> OnMethod { - on(MethodFilter::Head, svc) -} - -pub fn options(svc: S) -> OnMethod { - on(MethodFilter::Options, svc) -} - -pub fn patch(svc: S) -> OnMethod { - on(MethodFilter::Patch, svc) -} - -pub fn post(svc: S) -> OnMethod { - on(MethodFilter::Post, svc) -} - -pub fn put(svc: S) -> OnMethod { - on(MethodFilter::Put, svc) -} - -pub fn trace(svc: S) -> OnMethod { - on(MethodFilter::Trace, svc) -} - -pub fn on(method: MethodFilter, svc: S) -> OnMethod { - OnMethod { - method, - svc, - fallback: EmptyRouter, - } -} - -#[derive(Clone)] -pub struct OnMethod { - pub(crate) method: MethodFilter, - pub(crate) svc: S, - pub(crate) fallback: F, -} - -impl OnMethod { - pub fn any(self, svc: T) -> OnMethod - where - T: Service> + Clone, - { - self.on(MethodFilter::Any, svc) - } - - pub fn connect(self, svc: T) -> OnMethod - where - T: Service> + Clone, - { - self.on(MethodFilter::Connect, svc) - } - - pub fn delete(self, svc: T) -> OnMethod - where - T: Service> + Clone, - { - self.on(MethodFilter::Delete, svc) - } - - pub fn get(self, svc: T) -> OnMethod - where - T: Service> + Clone, - { - self.on(MethodFilter::Get, svc) - } - - pub fn head(self, svc: T) -> OnMethod - where - T: Service> + Clone, - { - self.on(MethodFilter::Head, svc) - } - - pub fn options(self, svc: T) -> OnMethod - where - T: Service> + Clone, - { - self.on(MethodFilter::Options, svc) - } - - pub fn patch(self, svc: T) -> OnMethod - where - T: Service> + Clone, - { - self.on(MethodFilter::Patch, svc) - } - - pub fn post(self, svc: T) -> OnMethod - where - T: Service> + Clone, - { - self.on(MethodFilter::Post, svc) - } - - pub fn put(self, svc: T) -> OnMethod - where - T: Service> + Clone, - { - self.on(MethodFilter::Put, svc) - } - - pub fn trace(self, svc: T) -> OnMethod - where - T: Service> + Clone, - { - self.on(MethodFilter::Trace, svc) - } - - pub fn on(self, method: MethodFilter, svc: T) -> OnMethod - where - T: Service> + Clone, - { - OnMethod { - method, - svc, - fallback: self, - } - } -} - -// this is identical to `routing::OnMethod`'s implementation. Would be nice to find a way to clean -// that up, but not sure its possible. -impl Service> for OnMethod -where - S: Service, Response = Response, Error = Infallible> + Clone, - SB: http_body::Body + Send + Sync + 'static, - SB::Error: Into, - - F: Service, Response = Response, Error = Infallible> + Clone, - FB: http_body::Body + Send + Sync + 'static, - FB::Error: Into, -{ - type Response = Response; - type Error = Infallible; - type Future = RouteFuture; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { - let f = if self.method.matches(req.method()) { - let response_future = self.svc.clone().oneshot(req); - future::Either::Left(BoxResponseBody(response_future)) - } else { - let response_future = self.fallback.clone().oneshot(req); - future::Either::Right(BoxResponseBody(response_future)) - }; - RouteFuture(f) - } -} - -/// A [`Service`] adapter that handles errors with a closure. -/// -/// Create with [`handler::Layered::handle_error`](crate::handler::Layered::handle_error) or -/// [`routing::Layered::handle_error`](crate::routing::Layered::handle_error). See those methods -/// for more details. -#[derive(Clone)] -pub struct HandleError { - pub(crate) inner: S, - pub(crate) f: F, -} - -impl crate::routing::RoutingDsl for HandleError {} - -impl HandleError { - pub(crate) fn new(inner: S, f: F) -> Self { - Self { inner, f } - } -} - -impl fmt::Debug for HandleError -where - S: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("HandleError") - .field("inner", &self.inner) - .field("f", &format_args!("{}", std::any::type_name::())) - .finish() - } -} - -impl Service> for HandleError -where - S: Service, Response = Response> + Clone, - F: FnOnce(S::Error) -> Res + Clone, - Res: IntoResponse, - B: http_body::Body + Send + Sync + 'static, - B::Error: Into + Send + Sync + 'static, -{ - type Response = Response; - type Error = Infallible; - type Future = HandleErrorFuture>, F>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { - HandleErrorFuture { - f: Some(self.f.clone()), - inner: self.inner.clone().oneshot(req), - } - } -} - -#[pin_project] -pub struct HandleErrorFuture { - #[pin] - inner: Fut, - f: Option, -} - -impl Future for HandleErrorFuture -where - Fut: Future, E>>, - F: FnOnce(E) -> Res, - Res: IntoResponse, - B: http_body::Body + Send + Sync + 'static, - B::Error: Into + Send + Sync + 'static, -{ - type Output = Result, Infallible>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - match ready!(this.inner.poll(cx)) { - Ok(res) => Ok(res.map(BoxBody::new)).into(), - Err(err) => { - let f = this.f.take().unwrap(); - let res = f(err).into_response(); - Ok(res.map(BoxBody::new)).into() - } - } - } -} diff --git a/src/service/future.rs b/src/service/future.rs new file mode 100644 index 00000000..2190666e --- /dev/null +++ b/src/service/future.rs @@ -0,0 +1,47 @@ +//! [`Service`](tower::Service) future types. + +use crate::{body::BoxBody, response::IntoResponse}; +use bytes::Bytes; +use futures_util::ready; +use http::Response; +use pin_project::pin_project; +use std::{ + convert::Infallible, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tower::BoxError; + +/// Response future for [`HandleError`](super::HandleError). +#[pin_project] +#[derive(Debug)] +pub struct HandleErrorFuture { + #[pin] + pub(super) inner: Fut, + pub(super) f: Option, +} + +impl Future for HandleErrorFuture +where + Fut: Future, E>>, + F: FnOnce(E) -> Res, + Res: IntoResponse, + B: http_body::Body + Send + Sync + 'static, + B::Error: Into + Send + Sync + 'static, +{ + type Output = Result, Infallible>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + match ready!(this.inner.poll(cx)) { + Ok(res) => Ok(res.map(BoxBody::new)).into(), + Err(err) => { + let f = this.f.take().unwrap(); + let res = f(err).into_response(); + Ok(res.map(BoxBody::new)).into() + } + } + } +} diff --git a/src/service/mod.rs b/src/service/mod.rs new file mode 100644 index 00000000..891c11cc --- /dev/null +++ b/src/service/mod.rs @@ -0,0 +1,488 @@ +//! Use Tower [`Service`]s to handl requests. +//! +//! Most of the time applications will be written by composing +//! [handlers](crate::handler), however sometimes you might have some general +//! [`Service`] that you want to route requests to. That is enabled by the +//! functions in this module. +//! +//! # Example +//! +//! Using [`Redirect`] to redirect requests can be done like so: +//! +//! ``` +//! use tower_http::services::Redirect; +//! use tower_web::{service, handler, prelude::*}; +//! +//! async fn handler(request: Request) { /* ... */ } +//! +//! let redirect_service = Redirect::::permanent("/new".parse().unwrap()); +//! +//! let app = route("/old", service::get(redirect_service)) +//! .route("/new", handler::get(handler)); +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(tower::make::Shared::new(app)).await; +//! # }; +//! ``` +//! +//! [`Redirect`]: tower_http::services::Redirect + +use crate::{ + body::{Body, BoxBody}, + response::IntoResponse, + routing::{BoxResponseBody, EmptyRouter, MethodFilter, RouteFuture}, +}; +use bytes::Bytes; +use futures_util::future::Either; +use http::{Request, Response}; +use std::{ + convert::Infallible, + fmt, + task::{Context, Poll}, +}; +use tower::{util::Oneshot, BoxError, Service, ServiceExt as _}; + +pub mod future; + +/// Route `CONNECT` requests to the given service. +/// +/// See [`get`] for an example. +pub fn connect(svc: S) -> OnMethod +where + S: Service, Error = Infallible> + Clone, +{ + on(MethodFilter::Connect, svc) +} + +/// Route `DELETE` requests to the given service. +/// +/// See [`get`] for an example. +pub fn delete(svc: S) -> OnMethod +where + S: Service, Error = Infallible> + Clone, +{ + on(MethodFilter::Delete, svc) +} + +/// Route `GET` requests to the given service. +/// +/// # Example +/// +/// ```rust +/// use tower_web::{service, prelude::*}; +/// use http::Response; +/// use std::convert::Infallible; +/// use hyper::Body; +/// +/// let service = tower::service_fn(|request: Request| async { +/// Ok::<_, Infallible>(Response::new(Body::empty())) +/// }); +/// +/// // Requests to `GET /` will go to `service`. +/// let app = route("/", service::get(service)); +/// ``` +/// +/// You can only add services who cannot fail (their error type must be +/// [`Infallible`]). To gracefully handle errors see [`ServiceExt::handle_error`]. +pub fn get(svc: S) -> OnMethod +where + S: Service, Error = Infallible> + Clone, +{ + on(MethodFilter::Get, svc) +} + +/// Route `HEAD` requests to the given service. +/// +/// See [`get`] for an example. +pub fn head(svc: S) -> OnMethod +where + S: Service, Error = Infallible> + Clone, +{ + on(MethodFilter::Head, svc) +} + +/// Route `OPTIONS` requests to the given service. +/// +/// See [`get`] for an example. +pub fn options(svc: S) -> OnMethod +where + S: Service, Error = Infallible> + Clone, +{ + on(MethodFilter::Options, svc) +} + +/// Route `PATCH` requests to the given service. +/// +/// See [`get`] for an example. +pub fn patch(svc: S) -> OnMethod +where + S: Service, Error = Infallible> + Clone, +{ + on(MethodFilter::Patch, svc) +} + +/// Route `POST` requests to the given service. +/// +/// See [`get`] for an example. +pub fn post(svc: S) -> OnMethod +where + S: Service, Error = Infallible> + Clone, +{ + on(MethodFilter::Post, svc) +} + +/// Route `PUT` requests to the given service. +/// +/// See [`get`] for an example. +pub fn put(svc: S) -> OnMethod +where + S: Service, Error = Infallible> + Clone, +{ + on(MethodFilter::Put, svc) +} + +/// Route `TRACE` requests to the given service. +/// +/// See [`get`] for an example. +pub fn trace(svc: S) -> OnMethod +where + S: Service, Error = Infallible> + Clone, +{ + on(MethodFilter::Trace, svc) +} + +/// Route requests with the given method to the service. +/// +/// # Example +/// +/// ```rust +/// use tower_web::{handler::on, service, routing::MethodFilter, prelude::*}; +/// use http::Response; +/// use std::convert::Infallible; +/// use hyper::Body; +/// +/// let service = tower::service_fn(|request: Request| async { +/// Ok::<_, Infallible>(Response::new(Body::empty())) +/// }); +/// +/// // Requests to `POST /` will go to `service`. +/// let app = route("/", service::on(MethodFilter::Post, service)); +/// ``` +pub fn on(method: MethodFilter, svc: S) -> OnMethod +where + S: Service, Error = Infallible> + Clone, +{ + OnMethod { + method, + svc, + fallback: EmptyRouter, + } +} + +/// A [`Service`] that accepts requests based on a [`MethodFilter`] and allows +/// chaining additional services. +#[derive(Clone, Debug)] +pub struct OnMethod { + pub(crate) method: MethodFilter, + pub(crate) svc: S, + pub(crate) fallback: F, +} + +impl OnMethod { + /// Chain an additional service that will accept all requests regardless of + /// its HTTP method. + /// + /// See [`OnMethod::get`] for an example. + pub fn any(self, svc: T) -> OnMethod + where + T: Service, Error = Infallible> + Clone, + { + self.on(MethodFilter::Any, svc) + } + + /// Chain an additional service that will only accept `CONNECT` requests. + /// + /// See [`OnMethod::get`] for an example. + pub fn connect(self, svc: T) -> OnMethod + where + T: Service, Error = Infallible> + Clone, + { + self.on(MethodFilter::Connect, svc) + } + + /// Chain an additional service that will only accept `DELETE` requests. + /// + /// See [`OnMethod::get`] for an example. + pub fn delete(self, svc: T) -> OnMethod + where + T: Service, Error = Infallible> + Clone, + { + self.on(MethodFilter::Delete, svc) + } + + /// Chain an additional service that will only accept `GET` requests. + /// + /// # Example + /// + /// ```rust + /// use tower_web::{handler::on, service, routing::MethodFilter, prelude::*}; + /// use http::Response; + /// use std::convert::Infallible; + /// use hyper::Body; + /// + /// let service = tower::service_fn(|request: Request| async { + /// Ok::<_, Infallible>(Response::new(Body::empty())) + /// }); + /// + /// let other_service = tower::service_fn(|request: Request| async { + /// Ok::<_, Infallible>(Response::new(Body::empty())) + /// }); + /// + /// // Requests to `GET /` will go to `service` and `POST /` will go to + /// // `other_service`. + /// let app = route("/", service::post(service).get(other_service)); + /// ``` + /// + /// You can only add services who cannot fail (their error type must be + /// [`Infallible`]). To gracefully handle errors see + /// [`ServiceExt::handle_error`]. + pub fn get(self, svc: T) -> OnMethod + where + T: Service, Error = Infallible> + Clone, + { + self.on(MethodFilter::Get, svc) + } + + /// Chain an additional service that will only accept `HEAD` requests. + /// + /// See [`OnMethod::get`] for an example. + pub fn head(self, svc: T) -> OnMethod + where + T: Service, Error = Infallible> + Clone, + { + self.on(MethodFilter::Head, svc) + } + + /// Chain an additional service that will only accept `OPTIONS` requests. + /// + /// See [`OnMethod::get`] for an example. + pub fn options(self, svc: T) -> OnMethod + where + T: Service, Error = Infallible> + Clone, + { + self.on(MethodFilter::Options, svc) + } + + /// Chain an additional service that will only accept `PATCH` requests. + /// + /// See [`OnMethod::get`] for an example. + pub fn patch(self, svc: T) -> OnMethod + where + T: Service, Error = Infallible> + Clone, + { + self.on(MethodFilter::Patch, svc) + } + + /// Chain an additional service that will only accept `POST` requests. + /// + /// See [`OnMethod::get`] for an example. + pub fn post(self, svc: T) -> OnMethod + where + T: Service, Error = Infallible> + Clone, + { + self.on(MethodFilter::Post, svc) + } + + /// Chain an additional service that will only accept `PUT` requests. + /// + /// See [`OnMethod::get`] for an example. + pub fn put(self, svc: T) -> OnMethod + where + T: Service, Error = Infallible> + Clone, + { + self.on(MethodFilter::Put, svc) + } + + /// Chain an additional service that will only accept `TRACE` requests. + /// + /// See [`OnMethod::get`] for an example. + pub fn trace(self, svc: T) -> OnMethod + where + T: Service, Error = Infallible> + Clone, + { + self.on(MethodFilter::Trace, svc) + } + + /// Chain an additional service that will accept requests matching the given + /// `MethodFilter`. + /// + /// # Example + /// + /// ```rust + /// use tower_web::{handler::on, service, routing::MethodFilter, prelude::*}; + /// use http::Response; + /// use std::convert::Infallible; + /// use hyper::Body; + /// + /// let service = tower::service_fn(|request: Request| async { + /// Ok::<_, Infallible>(Response::new(Body::empty())) + /// }); + /// + /// let other_service = tower::service_fn(|request: Request| async { + /// Ok::<_, Infallible>(Response::new(Body::empty())) + /// }); + /// + /// // Requests to `DELETE /` will go to `service` + /// let app = route("/", service::on(MethodFilter::Delete, service)); + /// ``` + pub fn on(self, method: MethodFilter, svc: T) -> OnMethod + where + T: Service, Error = Infallible> + Clone, + { + OnMethod { + method, + svc, + fallback: self, + } + } +} + +// this is identical to `routing::OnMethod`'s implementation. Would be nice to find a way to clean +// that up, but not sure its possible. +impl Service> for OnMethod +where + S: Service, Response = Response, Error = Infallible> + Clone, + SB: http_body::Body + Send + Sync + 'static, + SB::Error: Into, + + F: Service, Response = Response, Error = Infallible> + Clone, + FB: http_body::Body + Send + Sync + 'static, + FB::Error: Into, +{ + type Response = Response; + type Error = Infallible; + type Future = RouteFuture; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let f = if self.method.matches(req.method()) { + let response_future = self.svc.clone().oneshot(req); + Either::Left(BoxResponseBody(response_future)) + } else { + let response_future = self.fallback.clone().oneshot(req); + Either::Right(BoxResponseBody(response_future)) + }; + RouteFuture(f) + } +} + +/// A [`Service`] adapter that handles errors with a closure. +/// +/// Created with +/// [`handler::Layered::handle_error`](crate::handler::Layered::handle_error) or +/// [`routing::Layered::handle_error`](crate::routing::Layered::handle_error). +/// See those methods for more details. +#[derive(Clone)] +pub struct HandleError { + pub(crate) inner: S, + pub(crate) f: F, +} + +impl crate::routing::RoutingDsl for HandleError {} + +impl crate::sealed::Sealed for HandleError {} + +impl HandleError { + pub(crate) fn new(inner: S, f: F) -> Self { + Self { inner, f } + } +} + +impl fmt::Debug for HandleError +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HandleError") + .field("inner", &self.inner) + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Service> for HandleError +where + S: Service, Response = Response> + Clone, + F: FnOnce(S::Error) -> Res + Clone, + Res: IntoResponse, + B: http_body::Body + Send + Sync + 'static, + B::Error: Into + Send + Sync + 'static, +{ + type Response = Response; + type Error = Infallible; + type Future = future::HandleErrorFuture>, F>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + future::HandleErrorFuture { + f: Some(self.f.clone()), + inner: self.inner.clone().oneshot(req), + } + } +} + +/// Extension trait that adds additional methods to [`Service`]. +pub trait ServiceExt: Service, Response = Response> { + /// Handle errors from a service. + /// + /// tower-web requires all handlers and services, that are part of the + /// router, to never return errors. If you route to [`Service`], not created + /// by tower-web, who's error isn't `Infallible` you can use this combinator + /// to handle the error. + /// + /// `handle_error` takes a closure that will map errors from the service + /// into responses. The closure's return type must implement + /// [`IntoResponse`]. + /// + /// # Example + /// + /// ```rust,no_run + /// use tower_web::{service::{self, ServiceExt}, prelude::*}; + /// use http::Response; + /// use tower::{service_fn, BoxError}; + /// + /// // A service that might fail with `std::io::Error` + /// let service = service_fn(|_: Request| async { + /// let res = Response::new(Body::empty()); + /// Ok::<_, std::io::Error>(res) + /// }); + /// + /// let app = route( + /// "/", + /// service.handle_error(|error: std::io::Error| { + /// // Handle error by returning something that implements `IntoResponse` + /// }), + /// ); + /// # + /// # async { + /// # hyper::Server::bind(&"".parse().unwrap()).serve(tower::make::Shared::new(app)).await; + /// # }; + /// ``` + fn handle_error(self, f: F) -> HandleError + where + Self: Sized, + F: FnOnce(Self::Error) -> Res, + Res: IntoResponse, + B: http_body::Body + Send + Sync + 'static, + B::Error: Into + Send + Sync + 'static, + { + HandleError::new(self, f) + } +} + +impl ServiceExt for S where S: Service, Response = Response> {} diff --git a/src/tests.rs b/src/tests.rs index d0b3207f..2380fcfb 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -305,16 +305,18 @@ async fn boxing() { #[tokio::test] async fn service_handlers() { - use crate::ServiceExt as _; - use std::convert::Infallible; + use crate::service::ServiceExt as _; use tower::service_fn; use tower_http::services::ServeFile; let app = route( "/echo", - service::post(service_fn(|req: Request| async move { - Ok::<_, Infallible>(Response::new(req.into_body())) - })), + service::post( + service_fn(|req: Request| async move { + Ok::<_, BoxError>(Response::new(req.into_body())) + }) + .handle_error(|_error: BoxError| StatusCode::INTERNAL_SERVER_ERROR), + ), ) .route( "/static/Cargo.toml",