diff --git a/Cargo.toml b/Cargo.toml index c22b22d8..803003dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ headers = { optional = true, version = "0.3" } askama = "0.10.5" bb8 = "0.7.0" bb8-postgres = "0.7.0" +futures = "0.3" hyper = { version = "0.14", features = ["full"] } reqwest = { version = "0.11", features = ["json", "stream"] } serde = { version = "1.0", features = ["derive"] } @@ -66,15 +67,7 @@ features = [ [dev-dependencies.tower-http] version = "0.1" -features = [ - "add-extension", - "auth", - "compression", - "compression-full", - "fs", - "redirect", - "trace", -] +features = ["full"] [package.metadata.docs.rs] all-features = true diff --git a/deny.toml b/deny.toml index 1e976877..0e588651 100644 --- a/deny.toml +++ b/deny.toml @@ -16,7 +16,12 @@ confidence-threshold = 0.8 multiple-versions = "deny" highlight = "all" skip-tree = [] -skip = [] +skip = [ + # iri-string uses old version + # iri-string pulled in by tower-http + # PR to update tower-http is https://github.com/tower-rs/tower-http/pull/110 + { name = "nom", version = "=5.1.2" }, +] [sources] unknown-registry = "warn" diff --git a/examples/bb8_connection_pool.rs b/examples/bb8_connection_pool.rs index 9ecfc999..be40c8af 100644 --- a/examples/bb8_connection_pool.rs +++ b/examples/bb8_connection_pool.rs @@ -21,7 +21,10 @@ async fn main() { // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); - app.serve(&addr).await.unwrap(); + hyper::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); } type ConnectionPool = Pool>; diff --git a/examples/error_handling_and_dependency_injection.rs b/examples/error_handling_and_dependency_injection.rs index 9bbac60b..64d918c8 100644 --- a/examples/error_handling_and_dependency_injection.rs +++ b/examples/error_handling_and_dependency_injection.rs @@ -37,7 +37,10 @@ async fn main() { // Run our application let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); - app.serve(&addr).await.unwrap(); + hyper::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); } /// Handler for `GET /users/:id`. diff --git a/examples/form.rs b/examples/form.rs index d4a70a8b..c9755cc2 100644 --- a/examples/form.rs +++ b/examples/form.rs @@ -1,5 +1,4 @@ use awebframework::prelude::*; -use http::Request; use serde::Deserialize; use std::net::SocketAddr; @@ -14,10 +13,13 @@ async fn main() { // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); - app.serve(&addr).await.unwrap(); + hyper::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); } -async fn show_form(_req: Request) -> response::Html<&'static str> { +async fn show_form() -> response::Html<&'static str> { response::Html( r#" diff --git a/examples/hello_world.rs b/examples/hello_world.rs index 6483c044..07f9a13d 100644 --- a/examples/hello_world.rs +++ b/examples/hello_world.rs @@ -12,7 +12,10 @@ async fn main() { // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); - app.serve(&addr).await.unwrap(); + hyper::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); } async fn handler() -> response::Html<&'static str> { diff --git a/examples/key_value_store.rs b/examples/key_value_store.rs index 6dd16842..d05446eb 100644 --- a/examples/key_value_store.rs +++ b/examples/key_value_store.rs @@ -58,7 +58,10 @@ async fn main() { // Run our app with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); - app.serve(&addr).await.unwrap(); + hyper::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); } type SharedState = Arc>; @@ -98,7 +101,7 @@ async fn list_keys(Extension(state): Extension) -> String { .join("\n") } -fn admin_routes() -> BoxRoute { +fn admin_routes() -> BoxRoute { async fn delete_all_keys(Extension(state): Extension) { state.write().unwrap().db.clear(); } diff --git a/examples/static_file_server.rs b/examples/static_file_server.rs index 8c8b63f8..fbb53934 100644 --- a/examples/static_file_server.rs +++ b/examples/static_file_server.rs @@ -20,5 +20,8 @@ async fn main() { let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); - app.serve(&addr).await.unwrap(); + hyper::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); } diff --git a/examples/templates.rs b/examples/templates.rs index cc1c8744..99a2ce54 100644 --- a/examples/templates.rs +++ b/examples/templates.rs @@ -13,7 +13,10 @@ async fn main() { // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); - app.serve(&addr).await.unwrap(); + hyper::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); } async fn greet(params: extract::UrlParamsMap) -> impl IntoResponse { diff --git a/examples/versioning.rs b/examples/versioning.rs index a80b714c..208926b7 100644 --- a/examples/versioning.rs +++ b/examples/versioning.rs @@ -14,7 +14,10 @@ async fn main() { // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); - app.serve(&addr).await.unwrap(); + hyper::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); } async fn handler(version: Version) { @@ -29,10 +32,13 @@ enum Version { } #[async_trait] -impl FromRequest for Version { +impl FromRequest for Version +where + B: Send, +{ type Rejection = Response; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut Request) -> Result { let params = extract::UrlParamsMap::from_request(req) .await .map_err(IntoResponse::into_response)?; diff --git a/examples/websocket.rs b/examples/websocket.rs index 1e1526c5..c2f9c5b4 100644 --- a/examples/websocket.rs +++ b/examples/websocket.rs @@ -51,7 +51,10 @@ async fn main() { // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); - app.serve(&addr).await.unwrap(); + hyper::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); } async fn handle_socket(mut socket: WebSocket) { diff --git a/src/body.rs b/src/body.rs index 9bcbdc56..51894c5a 100644 --- a/src/body.rs +++ b/src/body.rs @@ -27,7 +27,7 @@ impl BoxBody { pub fn new(body: B) -> Self where B: http_body::Body + Send + Sync + 'static, - B::Error: Into + Send + Sync + 'static, + B::Error: Into, { Self { inner: Box::pin(body.map_err(|error| BoxStdError(error.into()))), diff --git a/src/extract/mod.rs b/src/extract/mod.rs index 3b13392a..463c8863 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -25,7 +25,7 @@ //! //! let app = route("/users", post(create_user)); //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -40,10 +40,13 @@ //! struct ExtractUserAgent(HeaderValue); //! //! #[async_trait] -//! impl FromRequest for ExtractUserAgent { +//! impl FromRequest for ExtractUserAgent +//! where +//! B: Send, +//! { //! type Rejection = (StatusCode, &'static str); //! -//! async fn from_request(req: &mut Request) -> Result { +//! async fn from_request(req: &mut Request) -> Result { //! if let Some(user_agent) = req.headers().get(USER_AGENT) { //! Ok(ExtractUserAgent(user_agent.clone())) //! } else { @@ -60,7 +63,7 @@ //! //! let app = route("/foo", get(handler)); //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -85,7 +88,7 @@ //! //! let app = route("/foo", get(handler)); //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -107,7 +110,7 @@ //! //! let app = route("/users", post(create_user)); //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -142,7 +145,7 @@ //! //! let app = route("/users", post(create_user)); //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -161,17 +164,30 @@ //! //! let app = route("/users", post(create_user)); //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` -use crate::{body::Body, response::IntoResponse, util::ByteStr}; +use crate::{ + body::{BoxBody, BoxStdError}, + response::IntoResponse, + util::ByteStr, +}; use async_trait::async_trait; use bytes::{Buf, Bytes}; +use futures_util::stream::Stream; use http::{header, HeaderMap, Method, Request, Uri, Version}; +use http_body::Body; use rejection::*; use serde::de::DeserializeOwned; -use std::{collections::HashMap, convert::Infallible, mem, str::FromStr}; +use std::{ + collections::HashMap, + convert::Infallible, + mem, + pin::Pin, + str::FromStr, + task::{Context, Poll}, +}; pub mod rejection; @@ -179,35 +195,37 @@ pub mod rejection; /// /// See the [module docs](crate::extract) for more details. #[async_trait] -pub trait FromRequest: Sized { +pub trait FromRequest: Sized { /// If the extractor fails it'll use this "rejection" type. A rejection is /// a kind of error that can be converted into a response. type Rejection: IntoResponse; /// Perform the extraction. - async fn from_request(req: &mut Request) -> Result; + async fn from_request(req: &mut Request) -> Result; } #[async_trait] -impl FromRequest for Option +impl FromRequest for Option where - T: FromRequest, + T: FromRequest, + B: Send, { type Rejection = Infallible; - async fn from_request(req: &mut Request) -> Result, Self::Rejection> { + async fn from_request(req: &mut Request) -> Result, Self::Rejection> { Ok(T::from_request(req).await.ok()) } } #[async_trait] -impl FromRequest for Result +impl FromRequest for Result where - T: FromRequest, + T: FromRequest, + B: Send, { type Rejection = Infallible; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut Request) -> Result { Ok(T::from_request(req).await) } } @@ -237,6 +255,9 @@ where /// } /// /// let app = route("/list_things", get(list_things)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; /// ``` /// /// If the query string cannot be parsed it will reject the request with a `400 @@ -245,13 +266,14 @@ where pub struct Query(pub T); #[async_trait] -impl FromRequest for Query +impl FromRequest for Query where T: DeserializeOwned, + B: Send, { type Rejection = QueryRejection; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut Request) -> Result { let query = req.uri().query().ok_or(QueryStringMissing)?; let value = serde_urlencoded::from_str(query) .map_err(FailedToDeserializeQueryString::new::)?; @@ -283,6 +305,9 @@ where /// } /// /// let app = route("/sign_up", post(accept_form)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; /// ``` /// /// Note that `Content-Type: multipart/form-data` requests are not supported. @@ -290,14 +315,17 @@ where pub struct Form(pub T); #[async_trait] -impl FromRequest for Form +impl FromRequest for Form where T: DeserializeOwned, + B: http_body::Body + Default + Send, + B::Data: Send, + B::Error: Into, { type Rejection = FormRejection; #[allow(warnings)] - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut Request) -> Result { if !has_content_type(&req, "application/x-www-form-urlencoded") { Err(InvalidFormContentType)?; } @@ -343,6 +371,9 @@ where /// } /// /// let app = route("/users", post(create_user)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; /// ``` /// /// If the query string cannot be parsed it will reject the request with a `400 @@ -353,13 +384,16 @@ where pub struct Json(pub T); #[async_trait] -impl FromRequest for Json +impl FromRequest for Json where T: DeserializeOwned, + B: http_body::Body + Default + Send, + B::Data: Send, + B::Error: Into, { type Rejection = JsonRejection; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut Request) -> Result { use bytes::Buf; if has_content_type(req, "application/json") { @@ -419,6 +453,9 @@ fn has_content_type(req: &Request, expected_content_type: &str) -> bool { /// // Add middleware that inserts the state into all incoming request's /// // extensions. /// .layer(AddExtensionLayer::new(state)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; /// ``` /// /// If the extension is missing it will reject the request with a `500 Interal @@ -427,13 +464,14 @@ fn has_content_type(req: &Request, expected_content_type: &str) -> bool { pub struct Extension(pub T); #[async_trait] -impl FromRequest for Extension +impl FromRequest for Extension where T: Clone + Send + Sync + 'static, + B: Send, { type Rejection = MissingExtension; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut Request) -> Result { let value = req .extensions() .get::() @@ -445,10 +483,15 @@ where } #[async_trait] -impl FromRequest for Bytes { +impl FromRequest for Bytes +where + B: http_body::Body + Default + Send, + B::Data: Send, + B::Error: Into, +{ type Rejection = BytesRejection; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut Request) -> Result { let body = take_body(req)?; let bytes = hyper::body::to_bytes(body) @@ -460,10 +503,15 @@ impl FromRequest for Bytes { } #[async_trait] -impl FromRequest for String { +impl FromRequest for String +where + B: http_body::Body + Default + Send, + B::Data: Send, + B::Error: Into, +{ type Rejection = StringRejection; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut Request) -> Result { let body = take_body(req)?; let bytes = hyper::body::to_bytes(body) @@ -477,20 +525,60 @@ impl FromRequest for String { } } -#[async_trait] -impl FromRequest for Body { - type Rejection = BodyAlreadyExtracted; +/// Extractor that extracts the request body as a [`Stream`]. +/// +/// # Example +/// +/// ```rust,no_run +/// use awebframework::prelude::*; +/// use futures::StreamExt; +/// +/// async fn handler(mut stream: extract::BodyStream) { +/// while let Some(chunk) = stream.next().await { +/// // ... +/// } +/// } +/// +/// let app = route("/users", get(handler)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +#[derive(Debug)] +pub struct BodyStream(BoxBody); - async fn from_request(req: &mut Request) -> Result { - take_body(req) +impl Stream for BodyStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_data(cx) } } #[async_trait] -impl FromRequest for Request { +impl FromRequest for BodyStream +where + B: http_body::Body + Default + Send + Sync + 'static, + B::Data: Send, + B::Error: Into, +{ + type Rejection = BodyAlreadyExtracted; + + async fn from_request(req: &mut Request) -> Result { + let body = take_body(req)?; + let stream = BodyStream(BoxBody::new(body)); + Ok(stream) + } +} + +#[async_trait] +impl FromRequest for Request +where + B: Default + Send, +{ type Rejection = RequestAlreadyExtracted; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut Request) -> Result { struct RequestAlreadyExtractedExt; if req @@ -506,37 +594,49 @@ impl FromRequest for Request { } #[async_trait] -impl FromRequest for Method { +impl FromRequest for Method +where + B: Send, +{ type Rejection = Infallible; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut Request) -> Result { Ok(req.method().clone()) } } #[async_trait] -impl FromRequest for Uri { +impl FromRequest for Uri +where + B: Send, +{ type Rejection = Infallible; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut Request) -> Result { Ok(req.uri().clone()) } } #[async_trait] -impl FromRequest for Version { +impl FromRequest for Version +where + B: Send, +{ type Rejection = Infallible; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut Request) -> Result { Ok(req.version()) } } #[async_trait] -impl FromRequest for HeaderMap { +impl FromRequest for HeaderMap +where + B: Send, +{ type Rejection = Infallible; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut Request) -> Result { Ok(mem::take(req.headers_mut())) } } @@ -553,6 +653,9 @@ impl FromRequest for HeaderMap { /// } /// /// let app = route("/", post(handler)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; /// ``` /// /// This requires the request to have a `Content-Length` header. @@ -560,13 +663,14 @@ impl FromRequest for HeaderMap { pub struct ContentLengthLimit(pub T); #[async_trait] -impl FromRequest for ContentLengthLimit +impl FromRequest for ContentLengthLimit where - T: FromRequest, + T: FromRequest, + B: Send, { type Rejection = ContentLengthLimitRejection; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut Request) -> Result { let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned(); let content_length = @@ -604,6 +708,9 @@ where /// } /// /// let app = route("/users/:id", get(users_show)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; /// ``` /// /// Note that you can only have one URL params extractor per handler. If you @@ -627,10 +734,13 @@ impl UrlParamsMap { } #[async_trait] -impl FromRequest for UrlParamsMap { +impl FromRequest for UrlParamsMap +where + B: Send, +{ type Rejection = MissingRouteParams; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut Request) -> Result { if let Some(params) = req .extensions_mut() .get_mut::>() @@ -664,6 +774,9 @@ impl FromRequest for UrlParamsMap { /// } /// /// let app = route("/users/:user_id/team/:team_id", get(users_teams_show)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; /// ``` /// /// Note that you can only have one URL params extractor per handler. If you @@ -676,15 +789,16 @@ macro_rules! impl_parse_url { ( $head:ident, $($tail:ident),* $(,)? ) => { #[async_trait] - impl<$head, $($tail,)*> FromRequest for UrlParams<($head, $($tail,)*)> + impl FromRequest for UrlParams<($head, $($tail,)*)> where $head: FromStr + Send, $( $tail: FromStr + Send, )* + B: Send, { type Rejection = UrlParamsRejection; #[allow(non_snake_case)] - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut Request) -> Result { let params = if let Some(params) = req .extensions_mut() .get_mut::>() @@ -726,7 +840,10 @@ macro_rules! impl_parse_url { impl_parse_url!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); -fn take_body(req: &mut Request) -> Result { +fn take_body(req: &mut Request) -> Result +where + B: Default, +{ struct BodyAlreadyExtractedExt; if req @@ -740,33 +857,6 @@ fn take_body(req: &mut Request) -> Result { } } -macro_rules! impl_from_request_tuple { - () => {}; - - ( $head:ident, $($tail:ident),* $(,)? ) => { - #[allow(non_snake_case)] - #[async_trait] - impl FromRequest for ($head, $($tail,)*) - where - R: IntoResponse, - $head: FromRequest + Send, - $( $tail: FromRequest + Send, )* - { - type Rejection = R; - - async fn from_request(req: &mut Request) -> Result { - let $head = FromRequest::from_request(req).await?; - $( let $tail = FromRequest::from_request(req).await?; )* - Ok(($head, $($tail,)*)) - } - } - - impl_from_request_tuple!($($tail,)*); - }; -} - -impl_from_request_tuple!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); - /// Extractor that extracts a typed header value from [`headers`]. /// /// # Example @@ -782,6 +872,9 @@ impl_from_request_tuple!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, /// } /// /// let app = route("/users/:user_id/team/:team_id", get(users_teams_show)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; /// ``` #[cfg(feature = "headers")] #[cfg_attr(docsrs, doc(cfg(feature = "headers")))] @@ -791,13 +884,14 @@ pub struct TypedHeader(pub T); #[cfg(feature = "headers")] #[cfg_attr(docsrs, doc(cfg(feature = "headers")))] #[async_trait] -impl FromRequest for TypedHeader +impl FromRequest for TypedHeader where T: headers::Header, + B: Send, { type Rejection = rejection::TypedHeaderRejection; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut Request) -> Result { let header_values = req.headers().get_all(T::name()); T::decode(&mut header_values.iter()) .map(Self) diff --git a/src/extract/rejection.rs b/src/extract/rejection.rs index edd42913..d900c9ba 100644 --- a/src/extract/rejection.rs +++ b/src/extract/rejection.rs @@ -138,7 +138,7 @@ define_rejection! { define_rejection! { #[status = INTERNAL_SERVER_ERROR] - #[body = "Cannot have two `Request` extractors for a single handler"] + #[body = "Cannot have two `Request<_>` extractors for a single handler"] /// Rejection type used if you try and extract the request more than once. pub struct RequestAlreadyExtracted; } diff --git a/src/handler/mod.rs b/src/handler/mod.rs index 6999e00d..e1ea5d41 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -39,7 +39,7 @@ //! the [`extract`](crate::extract) module. use crate::{ - body::{Body, BoxBody}, + body::BoxBody, extract::FromRequest, response::IntoResponse, routing::{EmptyRouter, MethodFilter, RouteFuture}, @@ -71,10 +71,13 @@ pub mod future; /// /// // All requests to `/` will go to `handler` regardless of the HTTP method. /// let app = route("/", any(handler)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; /// ``` -pub fn any(handler: H) -> OnMethod, EmptyRouter> +pub fn any(handler: H) -> OnMethod, EmptyRouter> where - H: Handler, + H: Handler, { on(MethodFilter::Any, handler) } @@ -82,9 +85,9 @@ where /// Route `CONNECT` requests to the given handler. /// /// See [`get`] for an example. -pub fn connect(handler: H) -> OnMethod, EmptyRouter> +pub fn connect(handler: H) -> OnMethod, EmptyRouter> where - H: Handler, + H: Handler, { on(MethodFilter::Connect, handler) } @@ -92,9 +95,9 @@ where /// Route `DELETE` requests to the given handler. /// /// See [`get`] for an example. -pub fn delete(handler: H) -> OnMethod, EmptyRouter> +pub fn delete(handler: H) -> OnMethod, EmptyRouter> where - H: Handler, + H: Handler, { on(MethodFilter::Delete, handler) } @@ -110,10 +113,13 @@ where /// /// // Requests to `GET /` will go to `handler`. /// let app = route("/", get(handler)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; /// ``` -pub fn get(handler: H) -> OnMethod, EmptyRouter> +pub fn get(handler: H) -> OnMethod, EmptyRouter> where - H: Handler, + H: Handler, { on(MethodFilter::Get, handler) } @@ -121,9 +127,9 @@ where /// Route `HEAD` requests to the given handler. /// /// See [`get`] for an example. -pub fn head(handler: H) -> OnMethod, EmptyRouter> +pub fn head(handler: H) -> OnMethod, EmptyRouter> where - H: Handler, + H: Handler, { on(MethodFilter::Head, handler) } @@ -131,9 +137,9 @@ where /// Route `OPTIONS` requests to the given handler. /// /// See [`get`] for an example. -pub fn options(handler: H) -> OnMethod, EmptyRouter> +pub fn options(handler: H) -> OnMethod, EmptyRouter> where - H: Handler, + H: Handler, { on(MethodFilter::Options, handler) } @@ -141,9 +147,9 @@ where /// Route `PATCH` requests to the given handler. /// /// See [`get`] for an example. -pub fn patch(handler: H) -> OnMethod, EmptyRouter> +pub fn patch(handler: H) -> OnMethod, EmptyRouter> where - H: Handler, + H: Handler, { on(MethodFilter::Patch, handler) } @@ -151,9 +157,9 @@ where /// Route `POST` requests to the given handler. /// /// See [`get`] for an example. -pub fn post(handler: H) -> OnMethod, EmptyRouter> +pub fn post(handler: H) -> OnMethod, EmptyRouter> where - H: Handler, + H: Handler, { on(MethodFilter::Post, handler) } @@ -161,9 +167,9 @@ where /// Route `PUT` requests to the given handler. /// /// See [`get`] for an example. -pub fn put(handler: H) -> OnMethod, EmptyRouter> +pub fn put(handler: H) -> OnMethod, EmptyRouter> where - H: Handler, + H: Handler, { on(MethodFilter::Put, handler) } @@ -171,9 +177,9 @@ where /// Route `TRACE` requests to the given handler. /// /// See [`get`] for an example. -pub fn trace(handler: H) -> OnMethod, EmptyRouter> +pub fn trace(handler: H) -> OnMethod, EmptyRouter> where - H: Handler, + H: Handler, { on(MethodFilter::Trace, handler) } @@ -189,10 +195,13 @@ where /// /// // Requests to `POST /` will go to `handler`. /// let app = route("/", on(MethodFilter::Post, handler)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; /// ``` -pub fn on(method: MethodFilter, handler: H) -> OnMethod, EmptyRouter> +pub fn on(method: MethodFilter, handler: H) -> OnMethod, EmptyRouter> where - H: Handler, + H: Handler, { OnMethod { method, @@ -216,14 +225,14 @@ mod sealed { /// /// See the [module docs](crate::handler) for more details. #[async_trait] -pub trait Handler: Sized { +pub trait Handler: Sized { // This seals the trait. We cannot use the regular "sealed super trait" approach // due to coherence. #[doc(hidden)] type Sealed: sealed::HiddentTrait; /// Call the handler with the given request. - async fn call(self, req: Request) -> Response; + async fn call(self, req: Request) -> Response; /// Apply a [`tower::Layer`] to the handler. /// @@ -248,33 +257,38 @@ pub trait Handler: Sized { /// async fn handler() { /* ... */ } /// /// let layered_handler = handler.layer(ConcurrencyLimitLayer::new(64)); + /// let app = route("/", get(layered_handler)); + /// # async { + /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; /// ``` /// /// When adding middleware that might fail its required to handle those /// errors. See [`Layered::handle_error`] for more details. fn layer(self, layer: L) -> Layered where - L: Layer>, + L: Layer>, { Layered::new(layer.layer(IntoService::new(self))) } /// Convert the handler into a [`Service`]. - fn into_service(self) -> IntoService { + fn into_service(self) -> IntoService { IntoService::new(self) } } #[async_trait] -impl Handler<()> for F +impl Handler for F where F: FnOnce() -> Fut + Send + Sync, Fut: Future + Send, Res: IntoResponse, + B: Send + 'static, { type Sealed = sealed::Hidden; - async fn call(self, _req: Request) -> Response { + async fn call(self, _req: Request) -> Response { self().await.into_response().map(BoxBody::new) } } @@ -285,17 +299,18 @@ macro_rules! impl_handler { ( $head:ident, $($tail:ident),* $(,)? ) => { #[async_trait] #[allow(non_snake_case)] - impl Handler<($head, $($tail,)*)> for F + impl Handler for F where F: FnOnce($head, $($tail,)*) -> Fut + Send + Sync, Fut: Future + Send, + B: Send + 'static, Res: IntoResponse, - $head: FromRequest + Send, - $( $tail: FromRequest + Send, )* + $head: FromRequest + Send, + $( $tail: FromRequest + Send, )* { type Sealed = sealed::Hidden; - async fn call(self, mut req: Request) -> Response { + async fn call(self, mut req: Request) -> Response { let $head = match $head::from_request(&mut req).await { Ok(value) => value, Err(rejection) => return rejection.into_response().map(BoxBody::new), @@ -347,17 +362,18 @@ where } #[async_trait] -impl Handler for Layered +impl Handler for Layered where - S: Service, Response = Response> + Send, + S: Service, Response = Response> + Send, S::Error: IntoResponse, S::Future: Send, - B: http_body::Body + Send + Sync + 'static, - B::Error: Into + Send + Sync + 'static, + ReqBody: Send + 'static, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into + Send + Sync + 'static, { type Sealed = sealed::Hidden; - async fn call(self, req: Request) -> Response { + async fn call(self, req: Request) -> Response { match self .svc .oneshot(req) @@ -413,12 +429,20 @@ impl Layered { /// ) /// } /// }); + /// + /// let app = route("/", get(layered_handler)); + /// # async { + /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; /// ``` /// /// The closure can return any type that implements [`IntoResponse`]. - pub fn handle_error(self, f: F) -> Layered, T> + pub fn handle_error( + self, + f: F, + ) -> Layered, T> where - S: Service, Response = Response>, + S: Service, Response = Response>, F: FnOnce(S::Error) -> Res, Res: IntoResponse, { @@ -430,12 +454,12 @@ impl Layered { /// An adapter that makes a [`Handler`] into a [`Service`]. /// /// Created with [`Handler::into_service`]. -pub struct IntoService { +pub struct IntoService { handler: H, - _marker: PhantomData T>, + _marker: PhantomData (B, T)>, } -impl IntoService { +impl IntoService { fn new(handler: H) -> Self { Self { handler, @@ -444,7 +468,7 @@ impl IntoService { } } -impl fmt::Debug for IntoService +impl fmt::Debug for IntoService where H: fmt::Debug, { @@ -455,7 +479,7 @@ where } } -impl Clone for IntoService +impl Clone for IntoService where H: Clone, { @@ -467,9 +491,10 @@ where } } -impl Service> for IntoService +impl Service> for IntoService where - H: Handler + Clone + Send + 'static, + H: Handler + Clone + Send + 'static, + B: Send + 'static, { type Response = Response; type Error = Infallible; @@ -482,7 +507,7 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { let handler = self.handler.clone(); let future = Box::pin(async move { let res = Handler::call(handler, req).await; @@ -506,9 +531,9 @@ impl OnMethod { /// its HTTP method. /// /// See [`OnMethod::get`] for an example. - pub fn any(self, handler: H) -> OnMethod, Self> + pub fn any(self, handler: H) -> OnMethod, Self> where - H: Handler, + H: Handler, { self.on(MethodFilter::Any, handler) } @@ -516,9 +541,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `CONNECT` requests. /// /// See [`OnMethod::get`] for an example. - pub fn connect(self, handler: H) -> OnMethod, Self> + pub fn connect(self, handler: H) -> OnMethod, Self> where - H: Handler, + H: Handler, { self.on(MethodFilter::Connect, handler) } @@ -526,9 +551,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `DELETE` requests. /// /// See [`OnMethod::get`] for an example. - pub fn delete(self, handler: H) -> OnMethod, Self> + pub fn delete(self, handler: H) -> OnMethod, Self> where - H: Handler, + H: Handler, { self.on(MethodFilter::Delete, handler) } @@ -547,10 +572,13 @@ impl OnMethod { /// // Requests to `GET /` will go to `handler` and `POST /` will go to /// // `other_handler`. /// let app = route("/", post(handler).get(other_handler)); + /// # async { + /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; /// ``` - pub fn get(self, handler: H) -> OnMethod, Self> + pub fn get(self, handler: H) -> OnMethod, Self> where - H: Handler, + H: Handler, { self.on(MethodFilter::Get, handler) } @@ -558,9 +586,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `HEAD` requests. /// /// See [`OnMethod::get`] for an example. - pub fn head(self, handler: H) -> OnMethod, Self> + pub fn head(self, handler: H) -> OnMethod, Self> where - H: Handler, + H: Handler, { self.on(MethodFilter::Head, handler) } @@ -568,9 +596,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `OPTIONS` requests. /// /// See [`OnMethod::get`] for an example. - pub fn options(self, handler: H) -> OnMethod, Self> + pub fn options(self, handler: H) -> OnMethod, Self> where - H: Handler, + H: Handler, { self.on(MethodFilter::Options, handler) } @@ -578,9 +606,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `PATCH` requests. /// /// See [`OnMethod::get`] for an example. - pub fn patch(self, handler: H) -> OnMethod, Self> + pub fn patch(self, handler: H) -> OnMethod, Self> where - H: Handler, + H: Handler, { self.on(MethodFilter::Patch, handler) } @@ -588,9 +616,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `POST` requests. /// /// See [`OnMethod::get`] for an example. - pub fn post(self, handler: H) -> OnMethod, Self> + pub fn post(self, handler: H) -> OnMethod, Self> where - H: Handler, + H: Handler, { self.on(MethodFilter::Post, handler) } @@ -598,9 +626,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `PUT` requests. /// /// See [`OnMethod::get`] for an example. - pub fn put(self, handler: H) -> OnMethod, Self> + pub fn put(self, handler: H) -> OnMethod, Self> where - H: Handler, + H: Handler, { self.on(MethodFilter::Put, handler) } @@ -608,9 +636,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `TRACE` requests. /// /// See [`OnMethod::get`] for an example. - pub fn trace(self, handler: H) -> OnMethod, Self> + pub fn trace(self, handler: H) -> OnMethod, Self> where - H: Handler, + H: Handler, { self.on(MethodFilter::Trace, handler) } @@ -630,10 +658,17 @@ impl OnMethod { /// // Requests to `GET /` will go to `handler` and `DELETE /` will go to /// // `other_handler` /// let app = route("/", get(handler).on(MethodFilter::Delete, other_handler)); + /// # async { + /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; /// ``` - pub fn on(self, method: MethodFilter, handler: H) -> OnMethod, Self> + pub fn on( + self, + method: MethodFilter, + handler: H, + ) -> OnMethod, Self> where - H: Handler, + H: Handler, { OnMethod { method, @@ -643,20 +678,20 @@ impl OnMethod { } } -impl Service> for OnMethod +impl Service> for OnMethod where - S: Service, Response = Response, Error = Infallible> + Clone, - F: Service, Response = Response, Error = Infallible> + Clone, + S: Service, Response = Response, Error = Infallible> + Clone, + F: Service, Response = Response, Error = Infallible> + Clone, { type Response = Response; type Error = Infallible; - type Future = RouteFuture; + type Future = RouteFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { if self.method.matches(req.method()) { let fut = self.svc.clone().oneshot(req); RouteFuture::a(fut) diff --git a/src/lib.rs b/src/lib.rs index 3ad89f1e..e3d4670e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,8 +33,8 @@ //! let app = route("/", get(|| async { "Hello, World!" })); //! //! // run it with hyper on localhost:3000 -//! app -//! .serve(&"0.0.0.0:3000".parse().unwrap()) +//! hyper::Server::bind(&"0.0.0.0:3000".parse().unwrap()) +//! .serve(app.into_make_service()) //! .await //! .unwrap(); //! } @@ -62,7 +62,7 @@ //! // `GET /foo` called //! } //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -141,7 +141,7 @@ //! .route("/result", get(result)) //! .route("/response", get(response)); //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -174,7 +174,7 @@ //! // ... //! } //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -194,7 +194,7 @@ //! // ... //! } //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -232,7 +232,7 @@ //! // ... //! } //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -247,7 +247,7 @@ //! // ... //! } //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -278,7 +278,7 @@ //! //! async fn handler() {} //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -298,7 +298,7 @@ //! //! async fn post_foo() {} //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -348,7 +348,7 @@ //! //! async fn handle() {} //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -374,7 +374,7 @@ //! //! async fn other_handle() {} //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -427,7 +427,7 @@ //! ); //! }); //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -458,7 +458,7 @@ //! // ... //! } //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -493,7 +493,7 @@ //! ) //! ); //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -505,19 +505,19 @@ //! Applications can be nested by calling [`nest`](routing::nest): //! //! ```rust,no_run -//! use awebframework::{prelude::*, routing::BoxRoute, body::BoxBody}; +//! use awebframework::{prelude::*, routing::BoxRoute, body::{Body, BoxBody}}; //! use tower_http::services::ServeFile; //! use http::Response; //! use std::convert::Infallible; //! -//! fn api_routes() -> BoxRoute { +//! fn api_routes() -> BoxRoute { //! route("/users", get(|_: Request| async { /* ... */ })).boxed() //! } //! //! let app = route("/", get(|_: Request| async { /* ... */ })) //! .nest("/api", api_routes()); //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -537,7 +537,7 @@ //! }) //! ); //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -686,9 +686,9 @@ pub mod prelude { /// # Panics /// /// Panics if `description` doesn't start with `/`. -pub fn route(description: &str, service: S) -> Route +pub fn route(description: &str, service: S) -> Route where - S: Service, Error = Infallible> + Clone, + S: Service, Error = Infallible> + Clone, { use routing::RoutingDsl; diff --git a/src/routing.rs b/src/routing.rs index 0d1c5e09..288b8500 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -6,7 +6,6 @@ use bytes::Bytes; use futures_util::{future, ready}; use http::{Method, Request, Response, StatusCode, Uri}; use http_body::Full; -use hyper::Body; use pin_project::pin_project; use regex::Regex; use std::{ @@ -101,10 +100,13 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { /// // and `GET /foo` goes to third_handler. /// let app = route("/", get(first_handler).post(second_handler)) /// .route("/foo", get(third_handler)); + /// # async { + /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; /// ``` - fn route(self, description: &str, svc: T) -> Route + fn route(self, description: &str, svc: T) -> Route where - T: Service, Error = Infallible> + Clone, + T: Service, Error = Infallible> + Clone, { Route { pattern: PathPattern::new(description), @@ -116,9 +118,9 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { /// Nest another service inside this router at the given path. /// /// See [`nest`] for more details. - fn nest(self, description: &str, svc: T) -> Nested + fn nest(self, description: &str, svc: T) -> Nested where - T: Service, Error = Infallible> + Clone, + T: Service, Error = Infallible> + Clone, { Nested { pattern: PathPattern::new(description), @@ -133,7 +135,7 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { /// return them from functions: /// /// ```rust - /// use awebframework::{routing::BoxRoute, prelude::*}; + /// use awebframework::{routing::BoxRoute, body::Body, prelude::*}; /// /// async fn first_handler() { /* ... */ } /// @@ -141,7 +143,7 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { /// /// async fn third_handler() { /* ... */ } /// - /// fn app() -> BoxRoute { + /// fn app() -> BoxRoute { /// route("/", get(first_handler).post(second_handler)) /// .route("/foo", get(third_handler)) /// .boxed() @@ -150,12 +152,16 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { /// /// It also helps with compile times when you have a very large number of /// routes. - fn boxed(self) -> BoxRoute + fn boxed(self) -> BoxRoute where - Self: Service, Response = Response, Error = Infallible> + Send + 'static, - >>::Future: Send, - B: http_body::Body + Send + Sync + 'static, - B::Error: Into + Send + Sync + 'static, + Self: Service, Response = Response, Error = Infallible> + + Send + + 'static, + >>::Future: Send, + ReqBody: http_body::Body + Send + Sync + 'static, + ReqBody::Error: Into + Send + Sync + 'static, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into + Send + Sync + 'static, { ServiceBuilder::new() .layer_fn(BoxRoute) @@ -200,7 +206,7 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { /// // wont be sent through `ConcurrencyLimit` /// .route("/bar", get(third_handler)); /// # async { - /// # app.serve(&"".parse().unwrap()).await.unwrap(); + /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// @@ -221,6 +227,9 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { /// .route("/foo", get(second_handler)) /// .route("/bar", get(third_handler)) /// .layer(TraceLayer::new_for_http()); + /// # async { + /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; /// ``` /// /// When adding middleware that might fail its required to handle those @@ -228,9 +237,8 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { fn layer(self, layer: L) -> Layered where L: Layer, - L::Service: Service> + Clone, { - Layered(layer.layer(self)) + Layered::new(layer.layer(self)) } /// Convert this router into a [`MakeService`], that is a [`Service`] who's @@ -259,52 +267,26 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { { tower::make::Shared::new(self) } - - /// Serve this router with [hyper] on the given address. - /// - /// Uses [`hyper::server::Server`]'s default configuration. Creating a - /// [`hyper::server::Server`] manually is recommended if different - /// configuration is needed. In that case [`into_make_service`] can be used - /// to easily serve this router. - /// - /// [hyper]: http://crates.io/crates/hyper - /// [`into_make_service`]: RoutingDsl::into_make_service - #[cfg(any(feature = "hyper-h1", feature = "hyper-h2"))] - #[cfg_attr(docsrs, doc(cfg(any(feature = "hyper-h1", feature = "hyper-h2"))))] - async fn serve(self, addr: &std::net::SocketAddr) -> Result<(), hyper::Error> - where - Self: Service, Response = Response, Error = Infallible> - + Clone - + Send - + 'static, - Self::Future: Send, - B: http_body::Body + Send + Sync + 'static, - B::Error: Into + Send + Sync + 'static, - { - hyper::server::Server::bind(addr) - .serve(self.into_make_service()) - .await - } } impl RoutingDsl for Route {} impl crate::sealed::Sealed for Route {} -impl Service> for Route +impl Service> for Route where - S: Service, Response = Response, Error = Infallible> + Clone, - F: Service, Response = Response, Error = Infallible> + Clone, + S: Service, Response = Response, Error = Infallible> + Clone, + F: Service, Response = Response, Error = Infallible> + Clone, { type Response = Response; type Error = Infallible; - type Future = RouteFuture; + type Future = RouteFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn call(&mut self, mut req: Request) -> Self::Future { + fn call(&mut self, mut req: Request) -> Self::Future { if let Some(captures) = self.pattern.full_match(req.uri().path()) { insert_url_params(&mut req, captures); let fut = self.svc.clone().oneshot(req); @@ -319,40 +301,40 @@ where /// The response future for [`Route`]. #[pin_project] #[derive(Debug)] -pub struct RouteFuture(#[pin] RouteFutureInner) +pub struct RouteFuture(#[pin] RouteFutureInner) where - S: Service>, - F: Service>; + S: Service>, + F: Service>; -impl RouteFuture +impl RouteFuture where - S: Service>, - F: Service>, + S: Service>, + F: Service>, { - pub(crate) fn a(a: Oneshot>) -> Self { + pub(crate) fn a(a: Oneshot>) -> Self { RouteFuture(RouteFutureInner::A(a)) } - pub(crate) fn b(b: Oneshot>) -> Self { + pub(crate) fn b(b: Oneshot>) -> Self { RouteFuture(RouteFutureInner::B(b)) } } #[pin_project(project = RouteFutureInnerProj)] #[derive(Debug)] -enum RouteFutureInner +enum RouteFutureInner where - S: Service>, - F: Service>, + S: Service>, + F: Service>, { - A(#[pin] Oneshot>), - B(#[pin] Oneshot>), + A(#[pin] Oneshot>), + B(#[pin] Oneshot>), } -impl Future for RouteFuture +impl Future for RouteFuture where - S: Service, Response = Response, Error = Infallible>, - F: Service, Response = Response, Error = Infallible>, + S: Service, Response = Response, Error = Infallible>, + F: Service, Response = Response, Error = Infallible>, { type Output = Result, Infallible>; @@ -393,7 +375,7 @@ impl RoutingDsl for EmptyRouter {} impl crate::sealed::Sealed for EmptyRouter {} -impl Service> for EmptyRouter { +impl Service> for EmptyRouter { type Response = Response; type Error = Infallible; type Future = EmptyRouterFuture; @@ -402,7 +384,7 @@ impl Service> for EmptyRouter { Poll::Ready(Ok(())) } - fn call(&mut self, _req: Request) -> Self::Future { + fn call(&mut self, _req: Request) -> Self::Future { let mut res = Response::new(BoxBody::empty()); *res.status_mut() = StatusCode::NOT_FOUND; EmptyRouterFuture(future::ok(res)) @@ -509,25 +491,28 @@ type Captures = Vec<(String, String)>; /// A boxed route trait object. /// /// See [`RoutingDsl::boxed`] for more details. -#[derive(Clone)] -pub struct BoxRoute( - Buffer, Response, Infallible>, Request>, -); +pub struct BoxRoute(Buffer, Response, Infallible>, Request>); -impl fmt::Debug for BoxRoute { +impl Clone for BoxRoute { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl fmt::Debug for BoxRoute { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("BoxRoute").finish() } } -impl RoutingDsl for BoxRoute {} +impl RoutingDsl for BoxRoute {} -impl crate::sealed::Sealed for BoxRoute {} +impl crate::sealed::Sealed for BoxRoute {} -impl Service> for BoxRoute { +impl Service> for BoxRoute { type Response = Response; type Error = Infallible; - type Future = BoxRouteFuture; + type Future = BoxRouteFuture; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { @@ -535,27 +520,25 @@ impl Service> for BoxRoute { } #[inline] - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { BoxRouteFuture(self.0.clone().oneshot(req)) } } /// The response future for [`BoxRoute`]. #[pin_project] -pub struct BoxRouteFuture(#[pin] InnerFuture); +pub struct BoxRouteFuture(#[pin] InnerFuture); -type InnerFuture = Oneshot< - Buffer, Response, Infallible>, Request>, - Request, ->; +type InnerFuture = + Oneshot, Response, Infallible>, Request>, Request>; -impl fmt::Debug for BoxRouteFuture { +impl fmt::Debug for BoxRouteFuture { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("BoxRouteFuture").finish() } } -impl Future for BoxRouteFuture { +impl Future for BoxRouteFuture { type Output = Result, Infallible>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -601,12 +584,39 @@ fn handle_buffer_error(error: BoxError) -> Response { /// A [`Service`] created from a router by applying a Tower middleware. /// /// Created with [`RoutingDsl::layer`]. See that method for more details. -#[derive(Clone, Debug)] -pub struct Layered(S); +pub struct Layered { + inner: S, +} + +impl Layered { + fn new(inner: S) -> Self { + Self { inner } + } +} + +impl Clone for Layered +where + S: Clone, +{ + fn clone(&self) -> Self { + Self::new(self.inner.clone()) + } +} + +impl fmt::Debug for Layered +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Layered") + .field("inner", &self.inner) + .finish() + } +} impl RoutingDsl for Layered {} -impl crate::sealed::Sealed for Layered {} +impl crate::sealed::Sealed for Layered {} impl Layered { /// Create a new [`Layered`] service where errors will be handled using the @@ -627,11 +637,11 @@ impl Layered { /// async fn handler() { /* ... */ } /// /// // `Timeout` will fail with `BoxError` if the timeout elapses... - /// let layered_handler = route("/", get(handler)) + /// let layered_app = route("/", get(handler)) /// .layer(TimeoutLayer::new(Duration::from_secs(30))); /// /// // ...so we must handle that error - /// let layered_handler = layered_handler.handle_error(|error: BoxError| { + /// let with_errors_handled = layered_app.handle_error(|error: BoxError| { /// if error.is::() { /// ( /// StatusCode::REQUEST_TIMEOUT, @@ -644,37 +654,47 @@ impl Layered { /// ) /// } /// }); + /// + /// # async { + /// # hyper::Server::bind(&"".parse().unwrap()) + /// # .serve(with_errors_handled.into_make_service()) + /// # .await + /// # .unwrap(); + /// # }; /// ``` /// /// The closure can return any type that implements [`IntoResponse`]. - pub fn handle_error(self, f: F) -> crate::service::HandleError + pub fn handle_error( + self, + f: F, + ) -> crate::service::HandleError where - S: Service, Response = Response> + Clone, + S: Service, Response = Response> + Clone, F: FnOnce(S::Error) -> Res, Res: IntoResponse, - B: http_body::Body + Send + Sync + 'static, - B::Error: Into + Send + Sync + 'static, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into + Send + Sync + 'static, { - crate::service::HandleError { inner: self.0, f } + crate::service::HandleError::new(self.inner, f) } } -impl Service> for Layered +impl Service for Layered where - S: Service, Response = Response, Error = Infallible>, + S: Service, { type Response = S::Response; - type Error = Infallible; + type Error = S::Error; type Future = S::Future; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.0.poll_ready(cx) + self.inner.poll_ready(cx) } #[inline] - fn call(&mut self, req: Request) -> Self::Future { - self.0.call(req) + fn call(&mut self, req: R) -> Self::Future { + self.inner.call(req) } } @@ -702,7 +722,7 @@ where /// /// let app = nest("/api", users_api).route("/careers", get(careers)); /// # async { -/// # app.serve(&"".parse().unwrap()).await.unwrap(); +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// @@ -723,7 +743,7 @@ where /// /// let app = nest("/:version/api", users_api); /// # async { -/// # app.serve(&"".parse().unwrap()).await.unwrap(); +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// @@ -742,16 +762,16 @@ where /// /// let app = nest("/public", get(serve_dir_service)); /// # async { -/// # app.serve(&"".parse().unwrap()).await.unwrap(); +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// If necessary you can use [`RoutingDsl::boxed`] to box a group of routes /// making the type easier to name. This is sometimes useful when working with /// `nest`. -pub fn nest(description: &str, svc: S) -> Nested +pub fn nest(description: &str, svc: S) -> Nested where - S: Service, Error = Infallible> + Clone, + S: Service, Error = Infallible> + Clone, { Nested { pattern: PathPattern::new(description), @@ -774,20 +794,20 @@ impl RoutingDsl for Nested {} impl crate::sealed::Sealed for Nested {} -impl Service> for Nested +impl Service> for Nested where - S: Service, Response = Response, Error = Infallible> + Clone, - F: Service, Response = Response, Error = Infallible> + Clone, + S: Service, Response = Response, Error = Infallible> + Clone, + F: Service, Response = Response, Error = Infallible> + Clone, { type Response = Response; type Error = Infallible; - type Future = RouteFuture; + type Future = RouteFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn call(&mut self, mut req: Request) -> Self::Future { + fn call(&mut self, mut req: Request) -> Self::Future { if let Some((prefix, captures)) = self.pattern.prefix_match(req.uri().path()) { let without_prefix = strip_prefix(req.uri(), prefix); *req.uri_mut() = without_prefix; diff --git a/src/service/mod.rs b/src/service/mod.rs index 307bf804..4e9da348 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -20,7 +20,7 @@ //! let app = route("/old", service::get(redirect_service)) //! .route("/new", handler::get(handler)); //! # async { -//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! @@ -70,6 +70,9 @@ //! let app = ServiceBuilder::new() //! .layer(some_backpressure_sensitive_middleware) //! .service(app); +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +//! # }; //! ``` //! //! However when applying middleware around your whole application in this way @@ -84,7 +87,7 @@ //! [load shed]: tower::load_shed use crate::{ - body::{Body, BoxBody}, + body::BoxBody, response::IntoResponse, routing::{EmptyRouter, MethodFilter, RouteFuture}, }; @@ -96,6 +99,7 @@ use std::{ convert::Infallible, fmt, future::Future, + marker::PhantomData, task::{Context, Poll}, }; use tower::{util::Oneshot, BoxError, Service, ServiceExt as _}; @@ -105,9 +109,9 @@ pub mod future; /// Route requests to the given service regardless of the HTTP method. /// /// See [`get`] for an example. -pub fn any(svc: S) -> OnMethod, EmptyRouter> +pub fn any(svc: S) -> OnMethod, EmptyRouter> where - S: Service, Error = Infallible> + Clone, + S: Service, Error = Infallible> + Clone, { on(MethodFilter::Any, svc) } @@ -115,9 +119,9 @@ where /// Route `CONNECT` requests to the given service. /// /// See [`get`] for an example. -pub fn connect(svc: S) -> OnMethod, EmptyRouter> +pub fn connect(svc: S) -> OnMethod, EmptyRouter> where - S: Service, Error = Infallible> + Clone, + S: Service, Error = Infallible> + Clone, { on(MethodFilter::Connect, svc) } @@ -125,9 +129,9 @@ where /// Route `DELETE` requests to the given service. /// /// See [`get`] for an example. -pub fn delete(svc: S) -> OnMethod, EmptyRouter> +pub fn delete(svc: S) -> OnMethod, EmptyRouter> where - S: Service, Error = Infallible> + Clone, + S: Service, Error = Infallible> + Clone, { on(MethodFilter::Delete, svc) } @@ -148,13 +152,16 @@ where /// /// // Requests to `GET /` will go to `service`. /// let app = route("/", service::get(service)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; /// ``` /// /// You can only add services who cannot fail (their error type must be /// [`Infallible`]). To gracefully handle errors see [`ServiceExt::handle_error`]. -pub fn get(svc: S) -> OnMethod, EmptyRouter> +pub fn get(svc: S) -> OnMethod, EmptyRouter> where - S: Service, Error = Infallible> + Clone, + S: Service, Error = Infallible> + Clone, { on(MethodFilter::Get, svc) } @@ -162,9 +169,9 @@ where /// Route `HEAD` requests to the given service. /// /// See [`get`] for an example. -pub fn head(svc: S) -> OnMethod, EmptyRouter> +pub fn head(svc: S) -> OnMethod, EmptyRouter> where - S: Service, Error = Infallible> + Clone, + S: Service, Error = Infallible> + Clone, { on(MethodFilter::Head, svc) } @@ -172,9 +179,9 @@ where /// Route `OPTIONS` requests to the given service. /// /// See [`get`] for an example. -pub fn options(svc: S) -> OnMethod, EmptyRouter> +pub fn options(svc: S) -> OnMethod, EmptyRouter> where - S: Service, Error = Infallible> + Clone, + S: Service, Error = Infallible> + Clone, { on(MethodFilter::Options, svc) } @@ -182,9 +189,9 @@ where /// Route `PATCH` requests to the given service. /// /// See [`get`] for an example. -pub fn patch(svc: S) -> OnMethod, EmptyRouter> +pub fn patch(svc: S) -> OnMethod, EmptyRouter> where - S: Service, Error = Infallible> + Clone, + S: Service, Error = Infallible> + Clone, { on(MethodFilter::Patch, svc) } @@ -192,9 +199,9 @@ where /// Route `POST` requests to the given service. /// /// See [`get`] for an example. -pub fn post(svc: S) -> OnMethod, EmptyRouter> +pub fn post(svc: S) -> OnMethod, EmptyRouter> where - S: Service, Error = Infallible> + Clone, + S: Service, Error = Infallible> + Clone, { on(MethodFilter::Post, svc) } @@ -202,9 +209,9 @@ where /// Route `PUT` requests to the given service. /// /// See [`get`] for an example. -pub fn put(svc: S) -> OnMethod, EmptyRouter> +pub fn put(svc: S) -> OnMethod, EmptyRouter> where - S: Service, Error = Infallible> + Clone, + S: Service, Error = Infallible> + Clone, { on(MethodFilter::Put, svc) } @@ -212,9 +219,9 @@ where /// Route `TRACE` requests to the given service. /// /// See [`get`] for an example. -pub fn trace(svc: S) -> OnMethod, EmptyRouter> +pub fn trace(svc: S) -> OnMethod, EmptyRouter> where - S: Service, Error = Infallible> + Clone, + S: Service, Error = Infallible> + Clone, { on(MethodFilter::Trace, svc) } @@ -235,14 +242,20 @@ where /// /// // Requests to `POST /` will go to `service`. /// let app = route("/", service::on(MethodFilter::Post, service)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; /// ``` -pub fn on(method: MethodFilter, svc: S) -> OnMethod, EmptyRouter> +pub fn on(method: MethodFilter, svc: S) -> OnMethod, EmptyRouter> where - S: Service, Error = Infallible> + Clone, + S: Service, Error = Infallible> + Clone, { OnMethod { method, - svc: BoxResponseBody(svc), + svc: BoxResponseBody { + inner: svc, + _request_body: PhantomData, + }, fallback: EmptyRouter, } } @@ -261,9 +274,9 @@ impl OnMethod { /// its HTTP method. /// /// See [`OnMethod::get`] for an example. - pub fn any(self, svc: T) -> OnMethod, Self> + pub fn any(self, svc: T) -> OnMethod, Self> where - T: Service, Error = Infallible> + Clone, + T: Service, Error = Infallible> + Clone, { self.on(MethodFilter::Any, svc) } @@ -271,9 +284,9 @@ impl OnMethod { /// Chain an additional service that will only accept `CONNECT` requests. /// /// See [`OnMethod::get`] for an example. - pub fn connect(self, svc: T) -> OnMethod, Self> + pub fn connect(self, svc: T) -> OnMethod, Self> where - T: Service, Error = Infallible> + Clone, + T: Service, Error = Infallible> + Clone, { self.on(MethodFilter::Connect, svc) } @@ -281,9 +294,9 @@ impl OnMethod { /// Chain an additional service that will only accept `DELETE` requests. /// /// See [`OnMethod::get`] for an example. - pub fn delete(self, svc: T) -> OnMethod, Self> + pub fn delete(self, svc: T) -> OnMethod, Self> where - T: Service, Error = Infallible> + Clone, + T: Service, Error = Infallible> + Clone, { self.on(MethodFilter::Delete, svc) } @@ -309,14 +322,17 @@ impl OnMethod { /// // Requests to `GET /` will go to `service` and `POST /` will go to /// // `other_service`. /// let app = route("/", service::post(service).get(other_service)); + /// # async { + /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; /// ``` /// /// You can only add services who cannot fail (their error type must be /// [`Infallible`]). To gracefully handle errors see /// [`ServiceExt::handle_error`]. - pub fn get(self, svc: T) -> OnMethod, Self> + pub fn get(self, svc: T) -> OnMethod, Self> where - T: Service, Error = Infallible> + Clone, + T: Service, Error = Infallible> + Clone, { self.on(MethodFilter::Get, svc) } @@ -324,9 +340,9 @@ impl OnMethod { /// Chain an additional service that will only accept `HEAD` requests. /// /// See [`OnMethod::get`] for an example. - pub fn head(self, svc: T) -> OnMethod, Self> + pub fn head(self, svc: T) -> OnMethod, Self> where - T: Service, Error = Infallible> + Clone, + T: Service, Error = Infallible> + Clone, { self.on(MethodFilter::Head, svc) } @@ -334,9 +350,9 @@ impl OnMethod { /// Chain an additional service that will only accept `OPTIONS` requests. /// /// See [`OnMethod::get`] for an example. - pub fn options(self, svc: T) -> OnMethod, Self> + pub fn options(self, svc: T) -> OnMethod, Self> where - T: Service, Error = Infallible> + Clone, + T: Service, Error = Infallible> + Clone, { self.on(MethodFilter::Options, svc) } @@ -344,9 +360,9 @@ impl OnMethod { /// Chain an additional service that will only accept `PATCH` requests. /// /// See [`OnMethod::get`] for an example. - pub fn patch(self, svc: T) -> OnMethod, Self> + pub fn patch(self, svc: T) -> OnMethod, Self> where - T: Service, Error = Infallible> + Clone, + T: Service, Error = Infallible> + Clone, { self.on(MethodFilter::Patch, svc) } @@ -354,9 +370,9 @@ impl OnMethod { /// Chain an additional service that will only accept `POST` requests. /// /// See [`OnMethod::get`] for an example. - pub fn post(self, svc: T) -> OnMethod, Self> + pub fn post(self, svc: T) -> OnMethod, Self> where - T: Service, Error = Infallible> + Clone, + T: Service, Error = Infallible> + Clone, { self.on(MethodFilter::Post, svc) } @@ -364,9 +380,9 @@ impl OnMethod { /// Chain an additional service that will only accept `PUT` requests. /// /// See [`OnMethod::get`] for an example. - pub fn put(self, svc: T) -> OnMethod, Self> + pub fn put(self, svc: T) -> OnMethod, Self> where - T: Service, Error = Infallible> + Clone, + T: Service, Error = Infallible> + Clone, { self.on(MethodFilter::Put, svc) } @@ -374,9 +390,9 @@ impl OnMethod { /// Chain an additional service that will only accept `TRACE` requests. /// /// See [`OnMethod::get`] for an example. - pub fn trace(self, svc: T) -> OnMethod, Self> + pub fn trace(self, svc: T) -> OnMethod, Self> where - T: Service, Error = Infallible> + Clone, + T: Service, Error = Infallible> + Clone, { self.on(MethodFilter::Trace, svc) } @@ -402,14 +418,20 @@ impl OnMethod { /// /// // Requests to `DELETE /` will go to `service` /// let app = route("/", service::on(MethodFilter::Delete, service)); + /// # async { + /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; /// ``` - pub fn on(self, method: MethodFilter, svc: T) -> OnMethod, Self> + pub fn on(self, method: MethodFilter, svc: T) -> OnMethod, Self> where - T: Service, Error = Infallible> + Clone, + T: Service, Error = Infallible> + Clone, { OnMethod { method, - svc: BoxResponseBody(svc), + svc: BoxResponseBody { + inner: svc, + _request_body: PhantomData, + }, fallback: self, } } @@ -417,20 +439,20 @@ impl OnMethod { // this is identical to `routing::OnMethod`'s implementation. Would be nice to find a way to clean // that up, but not sure its possible. -impl Service> for OnMethod +impl Service> for OnMethod where - S: Service, Response = Response, Error = Infallible> + Clone, - F: Service, Response = Response, Error = Infallible> + Clone, + S: Service, Response = Response, Error = Infallible> + Clone, + F: Service, Response = Response, Error = Infallible> + Clone, { type Response = Response; type Error = Infallible; - type Future = RouteFuture; + type Future = RouteFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { if self.method.matches(req.method()) { let fut = self.svc.clone().oneshot(req); RouteFuture::a(fut) @@ -447,23 +469,37 @@ where /// [`handler::Layered::handle_error`](crate::handler::Layered::handle_error) or /// [`routing::Layered::handle_error`](crate::routing::Layered::handle_error). /// See those methods for more details. -#[derive(Clone)] -pub struct HandleError { - pub(crate) inner: S, - pub(crate) f: F, +pub struct HandleError { + inner: S, + f: F, + _marker: PhantomData B>, } -impl crate::routing::RoutingDsl for HandleError {} - -impl crate::sealed::Sealed for HandleError {} - -impl HandleError { - pub(crate) fn new(inner: S, f: F) -> Self { - Self { inner, f } +impl Clone for HandleError +where + S: Clone, + F: Clone, +{ + fn clone(&self) -> Self { + Self::new(self.inner.clone(), self.f.clone()) } } -impl fmt::Debug for HandleError +impl crate::routing::RoutingDsl for HandleError {} + +impl crate::sealed::Sealed for HandleError {} + +impl HandleError { + pub(crate) fn new(inner: S, f: F) -> Self { + Self { + inner, + f, + _marker: PhantomData, + } + } +} + +impl fmt::Debug for HandleError where S: fmt::Debug, { @@ -475,23 +511,23 @@ where } } -impl Service> for HandleError +impl Service> for HandleError where - S: Service, Response = Response> + Clone, + S: Service, Response = Response> + Clone, F: FnOnce(S::Error) -> Res + Clone, Res: IntoResponse, - B: http_body::Body + Send + Sync + 'static, - B::Error: Into + Send + Sync + 'static, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into + Send + Sync + 'static, { type Response = Response; type Error = Infallible; - type Future = future::HandleErrorFuture>, F>; + type Future = future::HandleErrorFuture>, F>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { future::HandleErrorFuture { f: Some(self.f.clone()), inner: self.inner.clone().oneshot(req), @@ -500,7 +536,9 @@ where } /// Extension trait that adds additional methods to [`Service`]. -pub trait ServiceExt: Service, Response = Response> { +pub trait ServiceExt: + Service, Response = Response> +{ /// Handle errors from a service. /// /// awebframework requires all handlers and services, that are part of the @@ -533,43 +571,71 @@ pub trait ServiceExt: Service, Response = Response> { /// ); /// # /// # async { - /// # app.serve(&"".parse().unwrap()).await.unwrap(); + /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` - fn handle_error(self, f: F) -> HandleError + fn handle_error(self, f: F) -> HandleError where Self: Sized, F: FnOnce(Self::Error) -> Res, Res: IntoResponse, - B: http_body::Body + Send + Sync + 'static, - B::Error: Into + Send + Sync + 'static, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into + Send + Sync + 'static, { HandleError::new(self, f) } } -impl ServiceExt for S where S: Service, Response = Response> {} +impl ServiceExt for S where + S: Service, Response = Response> +{ +} /// A [`Service`] that boxes response bodies. -#[derive(Debug, Clone)] -pub struct BoxResponseBody(S); +pub struct BoxResponseBody { + inner: S, + _request_body: PhantomData B>, +} -impl Service> for BoxResponseBody +impl Clone for BoxResponseBody where - S: Service, Response = Response, Error = Infallible> + Clone, - B: http_body::Body + Send + Sync + 'static, - B::Error: Into + Send + Sync + 'static, + S: Clone, +{ + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + _request_body: PhantomData, + } + } +} + +impl fmt::Debug for BoxResponseBody +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BoxResponseBody") + .field("inner", &self.inner) + .finish() + } +} + +impl Service> for BoxResponseBody +where + S: Service, Response = Response, Error = Infallible> + Clone, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into + Send + Sync + 'static, { type Response = Response; type Error = Infallible; - type Future = BoxResponseBodyFuture>>; + type Future = BoxResponseBodyFuture>>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { - let fut = self.0.clone().oneshot(req); + fn call(&mut self, req: Request) -> Self::Future { + let fut = self.inner.clone().oneshot(req); BoxResponseBodyFuture(fut) } } diff --git a/src/tests.rs b/src/tests.rs index 50a48ffd..fc0e588e 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -8,7 +8,7 @@ use std::{ net::{SocketAddr, TcpListener}, time::Duration, }; -use tower::{make::Shared, BoxError, Service, ServiceBuilder}; +use tower::{make::Shared, service_fn, BoxError, Service, ServiceBuilder}; use tower_http::{compression::CompressionLayer, trace::TraceLayer}; #[tokio::test] @@ -329,7 +329,6 @@ async fn boxing() { #[tokio::test] async fn service_handlers() { use crate::service::ServiceExt as _; - use tower::service_fn; use tower_http::services::ServeFile; let app = route( @@ -607,6 +606,53 @@ async fn typed_header() { assert_eq!(body, "invalid HTTP header (user-agent)"); } +#[tokio::test] +async fn different_request_body_types() { + use http_body::{Empty, Full}; + use std::convert::Infallible; + use tower_http::map_request_body::MapRequestBodyLayer; + + async fn handler(body: String) -> String { + body + } + + async fn svc_handler(req: Request) -> Result, Infallible> + where + B: http_body::Body, + B::Error: std::fmt::Debug, + { + let body = hyper::body::to_bytes(req.into_body()).await.unwrap(); + Ok(Response::new(Body::from(body))) + } + + let app = route("/", service::get(service_fn(svc_handler))) + .route( + "/foo", + get(handler.layer(MapRequestBodyLayer::new(|_| Full::::from("foo")))), + ) + .layer(MapRequestBodyLayer::new(|_| Empty::::new())); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/", addr)) + .send() + .await + .unwrap(); + let body = res.text().await.unwrap(); + assert_eq!(body, ""); + + let res = client + .get(format!("http://{}/foo", addr)) + .send() + .await + .unwrap(); + let body = res.text().await.unwrap(); + assert_eq!(body, "foo"); +} + /// Run a `tower::Service` in the background and get a URI for it. async fn run_in_background(svc: S) -> SocketAddr where diff --git a/src/ws/mod.rs b/src/ws/mod.rs index 29269645..ab63cf00 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -13,6 +13,9 @@ //! socket.send(msg).await.unwrap(); //! } //! } +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +//! # }; //! ``` use crate::{ @@ -41,7 +44,7 @@ pub mod future; /// each connection. /// /// See the [module docs](crate::ws) for more details. -pub fn ws(callback: F) -> OnMethod>, EmptyRouter> +pub fn ws(callback: F) -> OnMethod, B>, EmptyRouter> where F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static, Fut: Future + Send + 'static, @@ -50,7 +53,7 @@ where callback, config: WebSocketConfig::default(), }; - crate::service::get(svc) + crate::service::get::<_, B>(svc) } /// [`Service`] that ugprades connections to websockets and spawns a task to