From a005427d40523a625a79a89a394f182077491910 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sun, 6 Jun 2021 15:19:54 +0200 Subject: [PATCH] Write some docs --- Cargo.toml | 3 +- src/extract/mod.rs | 4 +- src/handler.rs | 88 +++++++- src/lib.rs | 513 +++++++++++++++++++++++++++++++++++++++++++++ src/response.rs | 51 ++--- src/routing.rs | 84 +------- src/service.rs | 73 ++++++- src/tests.rs | 85 +++++++- 8 files changed, 773 insertions(+), 128 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e67dd602..0013785b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,9 +26,10 @@ hyper = { version = "0.14", features = ["full"] } reqwest = { version = "0.11", features = ["json", "stream"] } serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] } -tower = { version = "0.4", features = ["util", "make", "timeout"] } +tower = { version = "0.4", features = ["util", "make", "timeout", "limit", "load-shed"] } tracing = "0.1" tracing-subscriber = "0.2" +uuid = "0.8" [dev-dependencies.tower-http] version = "0.1" diff --git a/src/extract/mod.rs b/src/extract/mod.rs index b673d015..8c19fe59 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -31,7 +31,7 @@ where } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Default)] pub struct Query(pub T); #[async_trait] @@ -48,7 +48,7 @@ where } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Default)] pub struct Json(pub T); #[async_trait] diff --git a/src/handler.rs b/src/handler.rs index 4506a77b..67d4fb78 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,9 +1,9 @@ use crate::{ - body::Body, + body::{Body, BoxBody}, extract::FromRequest, response::IntoResponse, - routing::{EmptyRouter, MethodFilter, OnMethod}, - service::{self, HandleError}, + routing::{BoxResponseBody, EmptyRouter, MethodFilter}, + service::HandleError, }; use async_trait::async_trait; use bytes::Bytes; @@ -15,7 +15,7 @@ use std::{ marker::PhantomData, task::{Context, Poll}, }; -use tower::{BoxError, Layer, Service, ServiceExt}; +use tower::{util::Oneshot, BoxError, Layer, Service, ServiceExt}; pub fn get(handler: H) -> OnMethod, EmptyRouter> where @@ -35,7 +35,11 @@ pub fn on(method: MethodFilter, handler: H) -> OnMethod, { - service::on(method, handler.into_service()) + OnMethod { + method, + svc: handler.into_service(), + fallback: EmptyRouter, + } } mod sealed { @@ -236,3 +240,77 @@ where Box::pin(async move { Ok(Handler::call(handler, req).await) }) } } + +#[derive(Clone)] +pub struct OnMethod { + pub(crate) method: MethodFilter, + pub(crate) svc: S, + pub(crate) fallback: F, +} + +impl OnMethod { + pub fn get(self, handler: H) -> OnMethod, Self> + where + H: Handler, + { + self.on(MethodFilter::Get, handler) + } + + pub fn post(self, handler: H) -> OnMethod, Self> + where + H: Handler, + { + self.on(MethodFilter::Post, handler) + } + + pub fn on( + self, + method: MethodFilter, + handler: H, + ) -> OnMethod, Self> + where + H: Handler, + { + OnMethod { + method, + svc: handler.into_service(), + fallback: self, + } + } +} + +// this is identical to `routing::OnMethod`'s implementation. Would be nice to find a way to clean +// that up, but not sure its possible. +impl Service> for OnMethod +where + S: Service, Response = Response, Error = Infallible> + Clone, + SB: http_body::Body + Send + Sync + 'static, + SB::Error: Into, + + F: Service, Response = Response, Error = Infallible> + Clone, + FB: http_body::Body + Send + Sync + 'static, + FB::Error: Into, +{ + type Response = Response; + type Error = Infallible; + + #[allow(clippy::type_complexity)] + type Future = future::Either< + BoxResponseBody>>, + BoxResponseBody>>, + >; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + if self.method.matches(req.method()) { + let response_future = self.svc.clone().oneshot(req); + future::Either::Left(BoxResponseBody(response_future)) + } else { + let response_future = self.fallback.clone().oneshot(req); + future::Either::Right(BoxResponseBody(response_future)) + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 06492fae..2196c811 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,512 @@ +//! tower-web (name pending) is a tiny web application framework that focuses on +//! ergonimics and modularity. +//! +//! ## Goals +//! +//! - Ease of use. Build web apps in Rust should be as easy as `async fn +//! handle(Request) -> Response`. +//! - Solid foundation. tower-web is built on top of tower and makes it easy to +//! plug in any middleware from the [tower] and [tower-http] ecosystem. +//! - Focus on routing, extracing data from requests, and generating responses. +//! tower middleware can handle the rest. +//! - Macro free core. Macro frameworks have their place but tower-web focuses +//! on providing a core that is macro free. +//! +//! ## Non-goals +//! +//! - Runtime independent. tower-web is designed to work with tokio and hyper +//! and focused on bringing a good to experience to that stack. +//! - Speed. tower-web is a of course a fast framework, and wont be the +//! bottleneck in your app, but the goal is not to top the benchmarks. +//! +//! # Example +//! +//! The "Hello, World!" of tower-web is: +//! +//! ```rust,no_run +//! use tower_web::prelude::*; +//! use hyper::Server; +//! use std::net::SocketAddr; +//! use tower::make::Shared; +//! +//! #[tokio::main] +//! async fn main() { +//! // build our application with a single route +//! let app = route("/", get(handler)); +//! +//! // run it with hyper on localhost:3000 +//! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); +//! let server = Server::bind(&addr).serve(Shared::new(app)); +//! server.await.unwrap(); +//! } +//! +//! async fn handler(req: Request) -> &'static str { +//! "Hello, World!" +//! } +//! ``` +//! +//! # Routing +//! +//! Routing between handlers looks like this: +//! +//! ```rust,no_run +//! use tower_web::prelude::*; +//! +//! let app = route("/", get(get_slash).post(post_slash)) +//! .route("/foo", get(get_foo)); +//! +//! async fn get_slash(req: Request) { +//! // `GET /` called +//! } +//! +//! async fn post_slash(req: Request) { +//! // `POST /` called +//! } +//! +//! async fn get_foo(req: Request) { +//! // `GET /foo` called +//! } +//! # +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(tower::make::Shared::new(app)).await; +//! # }; +//! ``` +//! +//! Routes can also be dynamic like `/users/:id`. See ["Extracting data from +//! requests"](#extracting-data-from-requests) for more details on that. +//! +//! # Responses +//! +//! Anything that implements [`IntoResponse`] can be returned from a handler: +//! +//! ```rust,no_run +//! use tower_web::{body::Body, response::{Html, Json}, prelude::*}; +//! use http::{StatusCode, Response}; +//! use serde_json::{Value, json}; +//! +//! // We've already seen returning &'static str +//! async fn plain_text(req: Request) -> &'static str { +//! "foo" +//! } +//! +//! // String works too and will get a text/plain content-type +//! async fn plain_text_string(req: Request) -> String { +//! format!("Hi from {}", req.uri().path()) +//! } +//! +//! // Bytes will get a `application/octet-stream` content-type +//! async fn bytes(req: Request) -> Vec { +//! vec![1, 2, 3, 4] +//! } +//! +//! // `()` gives an empty response +//! async fn empty(req: Request) {} +//! +//! // `StatusCode` gives an empty response with that status code +//! async fn empty_with_status(req: Request) -> StatusCode { +//! StatusCode::NOT_FOUND +//! } +//! +//! // A tuple of `StatusCode` and something that implements `IntoResponse` can +//! // be used to override the status code +//! async fn with_status(req: Request) -> (StatusCode, &'static str) { +//! (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong") +//! } +//! +//! // `Html` gives a content-type of `text/html` +//! async fn html(req: Request) -> Html<&'static str> { +//! Html("

Hello, World!

") +//! } +//! +//! // `Json` gives a content-type of `application/json` and works with my type +//! // that implements `serde::Serialize` +//! async fn json(req: Request) -> Json { +//! Json(json!({ "data": 42 })) +//! } +//! +//! // `Result` where `T` and `E` implement `IntoResponse` is useful for +//! // returning errors +//! async fn result(req: Request) -> Result<&'static str, StatusCode> { +//! Ok("all good") +//! } +//! +//! // `Response` gives full control +//! async fn response(req: Request) -> Response { +//! Response::builder().body(Body::empty()).unwrap() +//! } +//! +//! let app = route("/plain_text", get(plain_text)) +//! .route("/plain_text_string", get(plain_text_string)) +//! .route("/bytes", get(bytes)) +//! .route("/empty", get(empty)) +//! .route("/empty_with_status", get(empty_with_status)) +//! .route("/with_status", get(with_status)) +//! .route("/html", get(html)) +//! .route("/json", get(json)) +//! .route("/result", get(result)) +//! .route("/response", get(response)); +//! # +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(tower::make::Shared::new(app)).await; +//! # }; +//! ``` +//! +//! See the [`response`] module for more details. +//! +//! # Extracting data from requests +//! +//! A handler function must always take `Request` as its first argument +//! but any arguments following are called "extractors". Any type that +//! implements [`FromRequest`](crate::extract::FromRequest) can be used as an +//! extractor. +//! +//! [`extract::Json`] is an extractor that consumes the request body and +//! deserializes as as JSON into some target type: +//! +//! ```rust,no_run +//! use tower_web::prelude::*; +//! use serde::Deserialize; +//! +//! let app = route("/users", post(create_user)); +//! +//! #[derive(Deserialize)] +//! struct CreateUser { +//! email: String, +//! password: String, +//! } +//! +//! async fn create_user(req: Request, payload: extract::Json) { +//! let payload: CreateUser = payload.0; +//! +//! // ... +//! } +//! # +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(tower::make::Shared::new(app)).await; +//! # }; +//! ``` +//! +//! [`extract::UrlParams`] can be used to extract params from a dynamic URL. It +//! is compatible with any type that implements [`std::str::FromStr`], such as +//! [`Uuid`]: +//! +//! ```rust,no_run +//! use tower_web::prelude::*; +//! use uuid::Uuid; +//! +//! let app = route("/users/:id", post(create_user)); +//! +//! async fn create_user(req: Request, params: extract::UrlParams<(Uuid,)>) { +//! let (user_id,) = params.0; +//! +//! // ... +//! } +//! # +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(tower::make::Shared::new(app)).await; +//! # }; +//! ``` +//! +//! There is also [`UrlParamsMap`](extract::UrlParamsMap) which provide a map +//! like API for extracting URL params. +//! +//! You can also apply multiple extractors: +//! +//! ```rust,no_run +//! use tower_web::prelude::*; +//! use uuid::Uuid; +//! use serde::Deserialize; +//! +//! let app = route("/users/:id/things", get(get_user_things)); +//! +//! #[derive(Deserialize)] +//! struct Pagination { +//! page: usize, +//! per_page: usize, +//! } +//! +//! impl Default for Pagination { +//! fn default() -> Self { +//! Self { page: 1, per_page: 30 } +//! } +//! } +//! +//! async fn get_user_things( +//! req: Request, +//! params: extract::UrlParams<(Uuid,)>, +//! pagination: Option>, +//! ) { +//! let user_id: Uuid = (params.0).0; +//! let pagination: Pagination = pagination.unwrap_or_default().0; +//! +//! // ... +//! } +//! # +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(tower::make::Shared::new(app)).await; +//! # }; +//! ``` +//! +//! See the [`extract`] module for more details. +//! +//! [`Uuid`]: https://docs.rs/uuid/latest/uuid/ +//! +//! # Applying middleware +//! +//! tower-web is designed to take full advantage of the tower and tower-http +//! ecosystem of middleware: +//! +//! ## To individual handlers +//! +//! A middleware can be applied to a single handler like so: +//! +//! ```rust,no_run +//! use tower_web::prelude::*; +//! use tower::limit::ConcurrencyLimitLayer; +//! +//! let app = route( +//! "/", +//! get(handler.layer(ConcurrencyLimitLayer::new(100))), +//! ); +//! +//! async fn handler(req: Request) {} +//! # +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(tower::make::Shared::new(app)).await; +//! # }; +//! ``` +//! +//! ## To groups of routes +//! +//! Middleware can also be applied to a group of routes like so: +//! +//! ```rust,no_run +//! use tower_web::prelude::*; +//! use tower::limit::ConcurrencyLimitLayer; +//! +//! let app = route("/", get(get_slash)) +//! .route("/foo", post(post_foo)) +//! .layer(ConcurrencyLimitLayer::new(100)); +//! +//! async fn get_slash(req: Request) {} +//! +//! async fn post_foo(req: Request) {} +//! # +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(tower::make::Shared::new(app)).await; +//! # }; +//! ``` +//! +//! ## Error handling +//! +//! tower-web requires all errors to be handled. That is done by using +//! [`std::convert::Infallible`] as the error type in all its [`Service`] +//! implementations. +//! +//! For handlers created from async functions this is works automatically since +//! handlers must return something that implements [`IntoResponse`], even if its +//! a `Result`. +//! +//! However middleware might add new failure cases that has to be handled. For +//! that tower-web provides a `handle_error` combinator: +//! +//! ```rust,no_run +//! use tower_web::prelude::*; +//! use tower::{ +//! BoxError, timeout::{TimeoutLayer, error::Elapsed}, +//! }; +//! use std::{borrow::Cow, time::Duration}; +//! use http::StatusCode; +//! +//! let app = route( +//! "/", +//! get(handle +//! .layer(TimeoutLayer::new(Duration::from_secs(30))) +//! // `Timeout` uses `BoxError` as the error type +//! .handle_error(|error: BoxError| { +//! // Check if the actual error type is `Elapsed` which +//! // `Timeout` returns +//! if error.is::() { +//! return (StatusCode::REQUEST_TIMEOUT, "Request took too long".into()); +//! } +//! +//! // If we encounter some error we don't handle return a generic +//! // error +//! return ( +//! StatusCode::INTERNAL_SERVER_ERROR, +//! // `Cow` lets us return either `&str` or `String` +//! Cow::from(format!("Unhandled internal error: {}", error)), +//! ); +//! })), +//! ); +//! +//! async fn handle(req: Request) {} +//! # +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(tower::make::Shared::new(app)).await; +//! # }; +//! ``` +//! +//! The closure passed to `handle_error` must return something that implements +//! `IntoResponse`. +//! +//! `handle_error` is also available on a group of routes with middleware +//! applied: +//! +//! ```rust,no_run +//! use tower_web::prelude::*; +//! use tower::{ +//! BoxError, timeout::{TimeoutLayer, error::Elapsed}, +//! }; +//! use std::{borrow::Cow, time::Duration}; +//! use http::StatusCode; +//! +//! let app = route("/", get(handle)) +//! .layer(TimeoutLayer::new(Duration::from_secs(30))) +//! .handle_error(|error: BoxError| { +//! // ... +//! }); +//! +//! async fn handle(req: Request) {} +//! # +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(tower::make::Shared::new(app)).await; +//! # }; +//! ``` +//! +//! ## Applying multiple middleware +//! +//! [`tower::ServiceBuilder`] can be used to combine multiple middleware: +//! +//! ```rust,no_run +//! use tower_web::prelude::*; +//! use tower::{ +//! ServiceBuilder, BoxError, +//! load_shed::error::Overloaded, +//! timeout::error::Elapsed, +//! }; +//! use tower_http::compression::CompressionLayer; +//! use std::{borrow::Cow, time::Duration}; +//! use http::StatusCode; +//! +//! let middleware_stack = ServiceBuilder::new() +//! // Return an error after 30 seconds +//! .timeout(Duration::from_secs(30)) +//! // Shed load if we're receiving too many requests +//! .load_shed() +//! // Process at most 100 requests concurrently +//! .concurrency_limit(100) +//! // Compress response bodies +//! .layer(CompressionLayer::new()) +//! .into_inner(); +//! +//! let app = route("/", get(|_: Request| async { /* ... */ })) +//! .layer(middleware_stack) +//! .handle_error(|error: BoxError| { +//! if error.is::() { +//! return ( +//! StatusCode::SERVICE_UNAVAILABLE, +//! "Try again later".into(), +//! ); +//! } +//! +//! if error.is::() { +//! return ( +//! StatusCode::REQUEST_TIMEOUT, +//! "Request took too long".into(), +//! ); +//! }; +//! +//! return ( +//! StatusCode::INTERNAL_SERVER_ERROR, +//! Cow::from(format!("Unhandled internal error: {}", error)), +//! ); +//! }); +//! # +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(tower::make::Shared::new(app)).await; +//! # }; +//! ``` +//! +//! # Sharing state with handlers +//! +//! It is common to share some state between handlers for example to share a +//! pool of database connections or clients to other services. That can be done +//! using the [`AddExtension`] middleware (applied with [`AddExtensionLayer`]) +//! and the [`extract::Extension`] extractor: +//! +//! ```rust,no_run +//! use tower_web::{AddExtensionLayer, prelude::*}; +//! use std::sync::Arc; +//! +//! struct State { +//! // ... +//! } +//! +//! let shared_state = Arc::new(State { /* ... */ }); +//! +//! let app = route("/", get(handler)).layer(AddExtensionLayer::new(shared_state)); +//! +//! async fn handler( +//! req: Request, +//! state: extract::Extension>, +//! ) { +//! let state: Arc = state.0; +//! +//! // ... +//! } +//! # +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(tower::make::Shared::new(app)).await; +//! # }; +//! ``` +//! +//! # Routing to any [`Service`] +//! +//! tower-web also supports routing to general [`Service`]s: +//! +//! ```rust,no_run +//! use tower_web::{ +//! service, prelude::*, +//! // `ServiceExt` adds `handle_error` to any `Service` +//! ServiceExt, +//! }; +//! use tower_http::services::ServeFile; +//! use http::Response; +//! use std::convert::Infallible; +//! use tower::{service_fn, BoxError}; +//! +//! let app = route( +//! // Any request to `/` goes to a service +//! "/", +//! service_fn(|_: Request| async { +//! let res = Response::new(Body::from("Hi from `GET /`")); +//! Ok::<_, Infallible>(res) +//! }) +//! ).route( +//! // GET `/static/Cargo.toml` goes to a service from tower-http +//! "/static/Cargo.toml", +//! service::get( +//! ServeFile::new("Cargo.toml") +//! // Errors must be handled +//! .handle_error(|error: std::io::Error| { /* ... */ }) +//! ) +//! ); +//! # +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(tower::make::Shared::new(app)).await; +//! # }; +//! ``` +//! +//! See the [`service`] module for more details. +//! +//! # Nesting applications +//! +//! TODO +//! +//! [tower]: https://crates.io/crates/tower +//! [tower-http]: https://crates.io/crates/tower-http + // #![doc(html_root_url = "https://docs.rs/tower-http/0.1.0")] #![warn( clippy::all, @@ -76,6 +585,7 @@ pub mod prelude { response, route, routing::AddRoute, }; + pub use http::Request; } pub fn route(spec: &str, svc: S) -> Route @@ -102,6 +612,9 @@ impl ResultExt for Result { } pub trait ServiceExt: Service, Response = Response> { + // TODO(david): routing methods like get, post, etc like whats on OnMethod + // so you can do `route("...", service::get(svc).post(svc))` + fn handle_error(self, f: F) -> service::HandleError where Self: Sized, diff --git a/src/response.rs b/src/response.rs index ee0e8260..3c821bc4 100644 --- a/src/response.rs +++ b/src/response.rs @@ -2,7 +2,7 @@ use crate::Body; use bytes::Bytes; use http::{header, HeaderMap, HeaderValue, Response, StatusCode}; use serde::Serialize; -use std::convert::Infallible; +use std::{borrow::Cow, convert::Infallible}; use tower::util::Either; // TODO(david): can we change this to not be generic over the body and just use hyper::Body? @@ -10,12 +10,9 @@ pub trait IntoResponse { fn into_response(self) -> Response; } -impl IntoResponse for () -where - B: Default, -{ - fn into_response(self) -> Response { - Response::new(B::default()) +impl IntoResponse for () { + fn into_response(self) -> Response { + Response::new(Body::empty()) } } @@ -58,23 +55,25 @@ impl IntoResponse for Response { } impl IntoResponse for &'static str { + #[inline] fn into_response(self) -> Response { - Response::new(Body::from(self)) + Cow::Borrowed(self).into_response() } } impl IntoResponse for String { + #[inline] fn into_response(self) -> Response { - let mut res = Response::new(Body::from(self)); - res.headers_mut() - .insert(header::CONTENT_TYPE, HeaderValue::from_static("text/plain")); - res + Cow::<'static, str>::Owned(self).into_response() } } impl IntoResponse for std::borrow::Cow<'static, str> { fn into_response(self) -> Response { - Response::new(Body::from(self)) + let mut res = Response::new(Body::from(self)); + res.headers_mut() + .insert(header::CONTENT_TYPE, HeaderValue::from_static("text/plain")); + res } } @@ -122,12 +121,12 @@ impl IntoResponse for std::borrow::Cow<'static, [u8]> { } } -impl IntoResponse for StatusCode -where - B: Default, -{ - fn into_response(self) -> Response { - Response::builder().status(self).body(B::default()).unwrap() +impl IntoResponse for StatusCode { + fn into_response(self) -> Response { + Response::builder() + .status(self) + .body(Body::empty()) + .unwrap() } } @@ -194,17 +193,3 @@ where res } } - -pub struct Text(pub T); - -impl IntoResponse for Text -where - T: Into, -{ - fn into_response(self) -> Response { - let mut res = Response::new(self.0.into()); - res.headers_mut() - .insert(header::CONTENT_TYPE, HeaderValue::from_static("text/plain")); - res - } -} diff --git a/src/routing.rs b/src/routing.rs index d11d7142..5dfefd91 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -1,9 +1,4 @@ -use crate::{ - body::BoxBody, - handler::{self, Handler}, - response::IntoResponse, - ResultExt, -}; +use crate::{body::BoxBody, response::IntoResponse, ResultExt}; use bytes::Bytes; use futures_util::{future, ready}; use http::{Method, Request, Response, StatusCode}; @@ -43,7 +38,7 @@ pub enum MethodFilter { impl MethodFilter { #[allow(clippy::match_like_matches_macro)] - fn matches(self, method: &Method) -> bool { + pub(crate) fn matches(self, method: &Method) -> bool { match (self, method) { (MethodFilter::Any, _) | (MethodFilter::Connect, &Method::CONNECT) @@ -67,13 +62,6 @@ pub struct Route { pub(crate) fallback: F, } -#[derive(Clone)] -pub struct OnMethod { - pub(crate) method: MethodFilter, - pub(crate) svc: S, - pub(crate) fallback: F, -} - pub trait AddRoute: Sized { fn route(self, spec: &str, svc: T) -> Route where @@ -116,30 +104,6 @@ impl AddRoute for Route { } } -impl OnMethod { - pub fn get(self, handler: H) -> OnMethod, Self> - where - H: Handler, - { - self.on_method(MethodFilter::Get, handler.into_service()) - } - - pub fn post(self, handler: H) -> OnMethod, Self> - where - H: Handler, - { - self.on_method(MethodFilter::Post, handler.into_service()) - } - - pub fn on_method(self, method: MethodFilter, svc: T) -> OnMethod { - OnMethod { - method, - svc, - fallback: self, - } - } -} - // ===== Routing service impls ===== impl Service> for Route @@ -190,42 +154,8 @@ fn insert_url_params(req: &mut Request, params: Vec<(String, String)>) { } } -impl Service> for OnMethod -where - S: Service, Response = Response, Error = Infallible> + Clone, - SB: http_body::Body + Send + Sync + 'static, - SB::Error: Into, - - F: Service, Response = Response, Error = Infallible> + Clone, - FB: http_body::Body + Send + Sync + 'static, - FB::Error: Into, -{ - type Response = Response; - type Error = Infallible; - - #[allow(clippy::type_complexity)] - type Future = future::Either< - BoxResponseBody>>, - BoxResponseBody>>, - >; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { - if self.method.matches(req.method()) { - let response_future = self.svc.clone().oneshot(req); - future::Either::Left(BoxResponseBody(response_future)) - } else { - let response_future = self.fallback.clone().oneshot(req); - future::Either::Right(BoxResponseBody(response_future)) - } - } -} - #[pin_project] -pub struct BoxResponseBody(#[pin] F); +pub struct BoxResponseBody(#[pin] pub(crate) F); impl Future for BoxResponseBody where @@ -453,7 +383,7 @@ impl AddRoute for Layered { } impl Layered { - pub fn handle_error(self, f: F) -> HandleError + pub fn handle_error(self, f: F) -> HandleError where S: Service, Response = Response> + Clone, F: FnOnce(S::Error) -> Res, @@ -461,16 +391,16 @@ impl Layered { B: http_body::Body + Send + Sync + 'static, B::Error: Into + Send + Sync + 'static, { - HandleError { inner: self, f } + HandleError { inner: self.0, f } } } impl Service> for Layered where - S: Service, Response = Response>, + S: Service, Response = Response, Error = Infallible>, { type Response = S::Response; - type Error = S::Error; + type Error = Infallible; type Future = S::Future; #[inline] diff --git a/src/service.rs b/src/service.rs index 310c0faa..128bf879 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,9 +1,10 @@ use crate::{ body::{Body, BoxBody}, response::IntoResponse, - routing::{EmptyRouter, MethodFilter, OnMethod}, + routing::{BoxResponseBody, EmptyRouter, MethodFilter}, }; use bytes::Bytes; +use futures_util::future; use futures_util::ready; use http::{Request, Response}; use pin_project::pin_project; @@ -32,6 +33,76 @@ pub fn on(method: MethodFilter, svc: S) -> OnMethod { } } +#[derive(Clone)] +pub struct OnMethod { + pub(crate) method: MethodFilter, + pub(crate) svc: S, + pub(crate) fallback: F, +} + +impl OnMethod { + pub fn get(self, svc: T) -> OnMethod + where + T: Service> + Clone, + { + self.on(MethodFilter::Get, svc) + } + + pub fn post(self, svc: T) -> OnMethod + where + T: Service> + Clone, + { + self.on(MethodFilter::Post, svc) + } + + pub fn on(self, method: MethodFilter, svc: T) -> OnMethod + where + T: Service> + Clone, + { + OnMethod { + method, + svc, + fallback: self, + } + } +} + +// this is identical to `routing::OnMethod`'s implementation. Would be nice to find a way to clean +// that up, but not sure its possible. +impl Service> for OnMethod +where + S: Service, Response = Response, Error = Infallible> + Clone, + SB: http_body::Body + Send + Sync + 'static, + SB::Error: Into, + + F: Service, Response = Response, Error = Infallible> + Clone, + FB: http_body::Body + Send + Sync + 'static, + FB::Error: Into, +{ + type Response = Response; + type Error = Infallible; + + #[allow(clippy::type_complexity)] + type Future = future::Either< + BoxResponseBody>>, + BoxResponseBody>>, + >; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + if self.method.matches(req.method()) { + let response_future = self.svc.clone().oneshot(req); + future::Either::Left(BoxResponseBody(response_future)) + } else { + let response_future = self.fallback.clone().oneshot(req); + future::Either::Right(BoxResponseBody(response_future)) + } + } +} + #[derive(Clone)] pub struct HandleError { inner: S, diff --git a/src/tests.rs b/src/tests.rs index ca96b30c..efb7effc 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,4 +1,4 @@ -use crate::{extract, get, post, route, routing::MethodFilter, service, AddRoute, Handler}; +use crate::{extract, get, on, post, route, routing::MethodFilter, service, AddRoute, Handler}; use http::{Request, Response, StatusCode}; use hyper::{Body, Server}; use serde::Deserialize; @@ -276,8 +276,12 @@ async fn extracting_url_params() { async fn boxing() { let app = route( "/", - get(|_: Request| async { "hi from GET" }) - .post(|_: Request| async { "hi from POST" }), + on(MethodFilter::Get, |_: Request| async { + "hi from GET" + }) + .on(MethodFilter::Post, |_: Request| async { + "hi from POST" + }), ) .boxed(); @@ -307,12 +311,9 @@ async fn service_handlers() { let app = route( "/echo", - service::on( - MethodFilter::Post, - service_fn(|req: Request| async move { - Ok::<_, Infallible>(Response::new(req.into_body())) - }), - ), + service::post(service_fn(|req: Request| async move { + Ok::<_, Infallible>(Response::new(req.into_body())) + })), ) .route( "/static/Cargo.toml", @@ -347,6 +348,72 @@ async fn service_handlers() { assert!(res.text().await.unwrap().contains("edition =")); } +#[tokio::test] +async fn routing_between_services() { + use std::convert::Infallible; + use tower::service_fn; + + async fn handle(_: Request) -> &'static str { + "handler" + } + + let app = route( + "/one", + service::get(service_fn(|_: Request| async { + Ok::<_, Infallible>(Response::new(Body::from("one get"))) + })) + .post(service_fn(|_: Request| async { + Ok::<_, Infallible>(Response::new(Body::from("one post"))) + })) + .on( + MethodFilter::Put, + service_fn(|_: Request| async { + Ok::<_, Infallible>(Response::new(Body::from("one put"))) + }), + ), + ) + .route( + "/two", + service::on(MethodFilter::Get, handle.into_service()), + ); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/one", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "one get"); + + let res = client + .post(format!("http://{}/one", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "one post"); + + let res = client + .put(format!("http://{}/one", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "one put"); + + let res = client + .get(format!("http://{}/two", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "handler"); +} + #[tokio::test] async fn middleware_on_single_route() { use tower::ServiceBuilder;