Move things around a bit

This commit is contained in:
David Pedersen 2021-06-06 11:37:08 +02:00
parent c3977d0b71
commit 46398afc72
10 changed files with 446 additions and 346 deletions

View file

@ -2,7 +2,7 @@ use http::{Request, StatusCode};
use hyper::Server; use hyper::Server;
use std::net::SocketAddr; use std::net::SocketAddr;
use tower::make::Shared; use tower::make::Shared;
use tower_web::{body::Body, response, get, route, AddRoute, extract}; use tower_web::{body::Body, extract, get, response, route, AddRoute};
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {

View file

@ -9,8 +9,6 @@ use tower::BoxError;
pub use hyper::body::Body; pub use hyper::body::Body;
use crate::BoxStdError;
/// A boxed [`Body`] trait object. /// A boxed [`Body`] trait object.
pub struct BoxBody { pub struct BoxBody {
// when we've gotten rid of `BoxStdError` we should be able to change the error type to // when we've gotten rid of `BoxStdError` we should be able to change the error type to
@ -78,3 +76,15 @@ where
BoxBody::new(Full::from(s.into())) BoxBody::new(Full::from(s.into()))
} }
} }
// work around for `BoxError` not implementing `std::error::Error`
//
// This is currently required since tower-http's Compression middleware's body type's
// error only implements error when the inner error type does:
// https://github.com/tower-rs/tower-http/blob/master/tower-http/src/lib.rs#L310
//
// Fixing that is a breaking change to tower-http so we should wait a bit, but should
// totally fix it at some point.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct BoxStdError(#[from] pub(crate) tower::BoxError);

View file

@ -1,10 +1,17 @@
use crate::{body::Body, response::IntoResponse}; use crate::{body::Body, response::IntoResponse};
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use http::{header, Response, Request}; use http::{header, Request, Response};
use rejection::{
BodyAlreadyTaken, FailedToBufferBody, InvalidJsonBody, InvalidUtf8, LengthRequired,
MissingExtension, MissingJsonContentType, MissingRouteParams, PayloadTooLarge,
QueryStringMissing,
};
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::{collections::HashMap, convert::Infallible, str::FromStr}; use std::{collections::HashMap, convert::Infallible, str::FromStr};
pub mod rejection;
#[async_trait] #[async_trait]
pub trait FromRequest<B>: Sized { pub trait FromRequest<B>: Sized {
type Rejection: IntoResponse<B>; type Rejection: IntoResponse<B>;
@ -24,58 +31,6 @@ where
} }
} }
macro_rules! define_rejection {
(
#[status = $status:ident]
#[body = $body:expr]
pub struct $name:ident (());
) => {
#[derive(Debug)]
pub struct $name(());
impl IntoResponse<Body> for $name {
fn into_response(self) -> http::Response<Body> {
let mut res = http::Response::new(Body::from($body));
*res.status_mut() = http::StatusCode::$status;
res
}
}
};
(
#[status = $status:ident]
#[body = $body:expr]
pub struct $name:ident (BoxError);
) => {
#[derive(Debug)]
pub struct $name(tower::BoxError);
impl $name {
fn from_err<E>(err: E) -> Self
where
E: Into<tower::BoxError>,
{
Self(err.into())
}
}
impl IntoResponse<Body> for $name {
fn into_response(self) -> http::Response<Body> {
let mut res =
http::Response::new(Body::from(format!(concat!($body, ": {}"), self.0)));
*res.status_mut() = http::StatusCode::$status;
res
}
}
};
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Query string was invalid or missing"]
pub struct QueryStringMissing(());
}
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct Query<T>(pub T); pub struct Query<T>(pub T);
@ -96,18 +51,6 @@ where
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct Json<T>(pub T); pub struct Json<T>(pub T);
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Failed to parse the response body as JSON"]
pub struct InvalidJsonBody(BoxError);
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Expected request with `Content-Type: application/json`"]
pub struct MissingJsonContentType(());
}
#[async_trait] #[async_trait]
impl<T> FromRequest<Body> for Json<T> impl<T> FromRequest<Body> for Json<T>
where where
@ -116,7 +59,7 @@ where
type Rejection = Response<Body>; type Rejection = Response<Body>;
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
if has_content_type(&req, "application/json") { if has_content_type(req, "application/json") {
let body = take_body(req).map_err(IntoResponse::into_response)?; let body = take_body(req).map_err(IntoResponse::into_response)?;
let bytes = hyper::body::to_bytes(body) let bytes = hyper::body::to_bytes(body)
@ -151,12 +94,6 @@ fn has_content_type<B>(req: &Request<B>, expected_content_type: &str) -> bool {
content_type.starts_with(expected_content_type) content_type.starts_with(expected_content_type)
} }
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Missing request extension"]
pub struct MissingExtension(());
}
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct Extension<T>(pub T); pub struct Extension<T>(pub T);
@ -178,12 +115,6 @@ where
} }
} }
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Failed to buffer the request body"]
pub struct FailedToBufferBody(BoxError);
}
#[async_trait] #[async_trait]
impl FromRequest<Body> for Bytes { impl FromRequest<Body> for Bytes {
type Rejection = Response<Body>; type Rejection = Response<Body>;
@ -200,12 +131,6 @@ impl FromRequest<Body> for Bytes {
} }
} }
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Response body didn't contain valid UTF-8"]
pub struct InvalidUtf8(BoxError);
}
#[async_trait] #[async_trait]
impl FromRequest<Body> for String { impl FromRequest<Body> for String {
type Rejection = Response<Body>; type Rejection = Response<Body>;
@ -236,18 +161,6 @@ impl FromRequest<Body> for Body {
} }
} }
define_rejection! {
#[status = PAYLOAD_TOO_LARGE]
#[body = "Request payload is too large"]
pub struct PayloadTooLarge(());
}
define_rejection! {
#[status = LENGTH_REQUIRED]
#[body = "Content length header is required"]
pub struct LengthRequired(());
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct BytesMaxLength<const N: u64>(pub Bytes); pub struct BytesMaxLength<const N: u64>(pub Bytes);
@ -278,12 +191,6 @@ impl<const N: u64> FromRequest<Body> for BytesMaxLength<N> {
} }
} }
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "No url params found for matched route. This is a bug in tower-web. Please open an issue"]
pub struct MissingRouteParams(());
}
#[derive(Debug)] #[derive(Debug)]
pub struct UrlParamsMap(HashMap<String, String>); pub struct UrlParamsMap(HashMap<String, String>);
@ -394,12 +301,6 @@ macro_rules! impl_parse_url {
impl_parse_url!(T1, T2, T3, T4, T5, T6); impl_parse_url!(T1, T2, T3, T4, T5, T6);
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Cannot have two request body extractors for a single handler"]
pub struct BodyAlreadyTaken(());
}
fn take_body(req: &mut Request<Body>) -> Result<Body, BodyAlreadyTaken> { fn take_body(req: &mut Request<Body>) -> Result<Body, BodyAlreadyTaken> {
struct BodyAlreadyTakenExt; struct BodyAlreadyTakenExt;

108
src/extract/rejection.rs Normal file
View file

@ -0,0 +1,108 @@
use super::IntoResponse;
use crate::body::Body;
macro_rules! define_rejection {
(
#[status = $status:ident]
#[body = $body:expr]
pub struct $name:ident (());
) => {
#[derive(Debug)]
pub struct $name(pub(super) ());
impl IntoResponse<Body> for $name {
fn into_response(self) -> http::Response<Body> {
let mut res = http::Response::new(Body::from($body));
*res.status_mut() = http::StatusCode::$status;
res
}
}
};
(
#[status = $status:ident]
#[body = $body:expr]
pub struct $name:ident (BoxError);
) => {
#[derive(Debug)]
pub struct $name(pub(super) tower::BoxError);
impl $name {
pub(super) fn from_err<E>(err: E) -> Self
where
E: Into<tower::BoxError>,
{
Self(err.into())
}
}
impl IntoResponse<Body> for $name {
fn into_response(self) -> http::Response<Body> {
let mut res =
http::Response::new(Body::from(format!(concat!($body, ": {}"), self.0)));
*res.status_mut() = http::StatusCode::$status;
res
}
}
};
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Query string was invalid or missing"]
pub struct QueryStringMissing(());
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Failed to parse the response body as JSON"]
pub struct InvalidJsonBody(BoxError);
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Expected request with `Content-Type: application/json`"]
pub struct MissingJsonContentType(());
}
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Missing request extension"]
pub struct MissingExtension(());
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Failed to buffer the request body"]
pub struct FailedToBufferBody(BoxError);
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Response body didn't contain valid UTF-8"]
pub struct InvalidUtf8(BoxError);
}
define_rejection! {
#[status = PAYLOAD_TOO_LARGE]
#[body = "Request payload is too large"]
pub struct PayloadTooLarge(());
}
define_rejection! {
#[status = LENGTH_REQUIRED]
#[body = "Content length header is required"]
pub struct LengthRequired(());
}
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "No url params found for matched route. This is a bug in tower-web. Please open an issue"]
pub struct MissingRouteParams(());
}
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Cannot have two request body extractors for a single handler"]
pub struct BodyAlreadyTaken(());
}

View file

@ -1,4 +1,10 @@
use crate::{body::Body, HandleError, extract::FromRequest, response::IntoResponse}; use crate::{
body::Body,
extract::FromRequest,
response::IntoResponse,
routing::{EmptyRouter, MethodFilter, OnMethod},
service::{self, HandleError},
};
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use futures_util::future; use futures_util::future;
@ -9,9 +15,32 @@ use std::{
marker::PhantomData, marker::PhantomData,
task::{Context, Poll}, task::{Context, Poll},
}; };
use tower::{Layer, BoxError, Service, ServiceExt}; use tower::{BoxError, Layer, Service, ServiceExt};
pub fn get<H, B, T>(handler: H) -> OnMethod<IntoService<H, B, T>, EmptyRouter>
where
H: Handler<B, T>,
{
on(MethodFilter::Get, handler)
}
pub fn post<H, B, T>(handler: H) -> OnMethod<IntoService<H, B, T>, EmptyRouter>
where
H: Handler<B, T>,
{
on(MethodFilter::Post, handler)
}
pub fn on<H, B, T>(method: MethodFilter, handler: H) -> OnMethod<IntoService<H, B, T>, EmptyRouter>
where
H: Handler<B, T>,
{
service::on(method, handler.into_service())
}
mod sealed { mod sealed {
#![allow(unreachable_pub)]
pub trait HiddentTrait {} pub trait HiddentTrait {}
pub struct Hidden; pub struct Hidden;
impl HiddentTrait for Hidden {} impl HiddentTrait for Hidden {}
@ -30,9 +59,13 @@ pub trait Handler<B, In>: Sized {
fn layer<L>(self, layer: L) -> Layered<L::Service, In> fn layer<L>(self, layer: L) -> Layered<L::Service, In>
where where
L: Layer<HandlerSvc<Self, B, In>>, L: Layer<IntoService<Self, B, In>>,
{ {
Layered::new(layer.layer(HandlerSvc::new(self))) Layered::new(layer.layer(IntoService::new(self)))
}
fn into_service(self) -> IntoService<Self, B, In> {
IntoService::new(self)
} }
} }
@ -156,33 +189,33 @@ impl<S, T> Layered<S, T> {
} }
} }
pub struct HandlerSvc<H, B, T> { pub struct IntoService<H, B, T> {
handler: H, handler: H,
_input: PhantomData<fn() -> (B, T)>, _marker: PhantomData<fn() -> (B, T)>,
} }
impl<H, B, T> HandlerSvc<H, B, T> { impl<H, B, T> IntoService<H, B, T> {
pub(crate) fn new(handler: H) -> Self { fn new(handler: H) -> Self {
Self { Self {
handler, handler,
_input: PhantomData, _marker: PhantomData,
} }
} }
} }
impl<H, B, T> Clone for HandlerSvc<H, B, T> impl<H, B, T> Clone for IntoService<H, B, T>
where where
H: Clone, H: Clone,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
handler: self.handler.clone(), handler: self.handler.clone(),
_input: PhantomData, _marker: PhantomData,
} }
} }
} }
impl<H, B, T> Service<Request<Body>> for HandlerSvc<H, B, T> impl<H, B, T> Service<Request<Body>> for IntoService<H, B, T>
where where
H: Handler<B, T> + Clone + Send + 'static, H: Handler<B, T> + Clone + Send + 'static,
H::Response: 'static, H::Response: 'static,
@ -192,7 +225,7 @@ where
type Future = future::BoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// HandlerSvc can only be constructed from async functions which are always ready, or from // `IntoService` can only be constructed from async functions which are always ready, or from
// `Layered` which bufferes in `<Layered as Handler>::call` and is therefore also always // `Layered` which bufferes in `<Layered as Handler>::call` and is therefore also always
// ready. // ready.
Poll::Ready(Ok(())) Poll::Ready(Ok(()))

View file

@ -1,70 +1,73 @@
// #![doc(html_root_url = "https://docs.rs/tower-http/0.1.0")]
#![warn(
clippy::all,
clippy::dbg_macro,
clippy::todo,
clippy::empty_enum,
clippy::enum_glob_use,
clippy::pub_enum_variant_names,
clippy::mem_forget,
clippy::unused_self,
clippy::filter_map_next,
clippy::needless_continue,
clippy::needless_borrow,
clippy::match_wildcard_for_single_variants,
clippy::if_let_mutex,
clippy::mismatched_target_os,
clippy::await_holding_lock,
clippy::match_on_vec_items,
clippy::imprecise_flops,
clippy::suboptimal_flops,
clippy::lossy_float_literal,
clippy::rest_pat_in_fully_bound_structs,
clippy::fn_params_excessive_bools,
clippy::exit,
clippy::inefficient_to_string,
clippy::linkedlist,
clippy::macro_use_imports,
clippy::option_option,
clippy::verbose_file_reads,
clippy::unnested_or_patterns,
rust_2018_idioms,
future_incompatible,
nonstandard_style,
// missing_docs,
)]
#![deny(unreachable_pub, broken_intra_doc_links, private_in_public)]
#![allow(
elided_lifetimes_in_paths,
// TODO: Remove this once the MSRV bumps to 1.42.0 or above.
clippy::match_like_matches_macro,
clippy::type_complexity
)]
#![forbid(unsafe_code)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![cfg_attr(test, allow(clippy::float_cmp))]
use self::body::Body; use self::body::Body;
use body::BoxBody;
use bytes::Bytes; use bytes::Bytes;
use futures_util::ready; use http::{Request, Response};
use handler::HandlerSvc;
use http::{Method, Request, Response};
use pin_project::pin_project;
use response::IntoResponse; use response::IntoResponse;
use routing::{EmptyRouter, OnMethod, Route}; use routing::{EmptyRouter, Route};
use std::{ use std::convert::Infallible;
convert::Infallible, use tower::{BoxError, Service};
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::{util::Oneshot, BoxError, Service, ServiceExt as _};
pub mod body; pub mod body;
pub mod extract; pub mod extract;
pub mod handler; pub mod handler;
pub mod response; pub mod response;
pub mod routing; pub mod routing;
pub mod service;
#[doc(inline)] #[doc(inline)]
pub use self::handler::Handler; pub use self::{
#[doc(inline)] handler::{get, on, post, Handler},
pub use self::routing::AddRoute; routing::AddRoute,
};
pub use async_trait::async_trait; pub use async_trait::async_trait;
pub use tower_http::add_extension::{AddExtension, AddExtensionLayer}; pub use tower_http::add_extension::{AddExtension, AddExtensionLayer};
#[derive(Debug, Copy, Clone)]
pub enum MethodFilter {
Any,
Connect,
Delete,
Get,
Head,
Options,
Patch,
Post,
Put,
Trace,
}
impl MethodFilter {
#[allow(clippy::match_like_matches_macro)]
fn matches(self, method: &Method) -> bool {
use MethodFilter::*;
match (self, method) {
(Any, _)
| (Connect, &Method::CONNECT)
| (Delete, &Method::DELETE)
| (Get, &Method::GET)
| (Head, &Method::HEAD)
| (Options, &Method::OPTIONS)
| (Patch, &Method::PATCH)
| (Post, &Method::POST)
| (Put, &Method::PUT)
| (Trace, &Method::TRACE) => true,
_ => false,
}
}
}
pub fn route<S>(spec: &str, svc: S) -> Route<S, EmptyRouter> pub fn route<S>(spec: &str, svc: S) -> Route<S, EmptyRouter>
where where
S: Service<Request<Body>, Error = Infallible> + Clone, S: Service<Request<Body>, Error = Infallible> + Clone,
@ -72,28 +75,6 @@ where
routing::EmptyRouter.route(spec, svc) routing::EmptyRouter.route(spec, svc)
} }
pub fn get<H, B, T>(handler: H) -> OnMethod<HandlerSvc<H, B, T>, EmptyRouter>
where
H: Handler<B, T>,
{
on_method(MethodFilter::Get, HandlerSvc::new(handler))
}
pub fn post<H, B, T>(handler: H) -> OnMethod<HandlerSvc<H, B, T>, EmptyRouter>
where
H: Handler<B, T>,
{
on_method(MethodFilter::Post, HandlerSvc::new(handler))
}
pub fn on_method<S>(method: MethodFilter, svc: S) -> OnMethod<S, EmptyRouter> {
OnMethod {
method,
svc,
fallback: EmptyRouter,
}
}
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
@ -110,20 +91,8 @@ impl<T> ResultExt<T> for Result<T, Infallible> {
} }
} }
// work around for `BoxError` not implementing `std::error::Error`
//
// This is currently required since tower-http's Compression middleware's body type's
// error only implements error when the inner error type does:
// https://github.com/tower-rs/tower-http/blob/master/tower-http/src/lib.rs#L310
//
// Fixing that is a breaking change to tower-http so we should wait a bit, but should
// totally fix it at some point.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct BoxStdError(#[from] pub(crate) tower::BoxError);
pub trait ServiceExt<B>: Service<Request<Body>, Response = Response<B>> { pub trait ServiceExt<B>: Service<Request<Body>, Response = Response<B>> {
fn handle_error<F, Res>(self, f: F) -> HandleError<Self, F> fn handle_error<F, Res>(self, f: F) -> service::HandleError<Self, F>
where where
Self: Sized, Self: Sized,
F: FnOnce(Self::Error) -> Res, F: FnOnce(Self::Error) -> Res,
@ -131,87 +100,8 @@ pub trait ServiceExt<B>: Service<Request<Body>, Response = Response<B>> {
B: http_body::Body<Data = Bytes> + Send + Sync + 'static, B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static, B::Error: Into<BoxError> + Send + Sync + 'static,
{ {
HandleError::new(self, f) service::HandleError::new(self, f)
} }
} }
impl<S, B> ServiceExt<B> for S where S: Service<Request<Body>, Response = Response<B>> {} impl<S, B> ServiceExt<B> for S where S: Service<Request<Body>, Response = Response<B>> {}
#[derive(Clone)]
pub struct HandleError<S, F> {
inner: S,
f: F,
}
impl<S, F> HandleError<S, F> {
pub(crate) fn new(inner: S, f: F) -> Self {
Self { inner, f }
}
}
impl<S, F> fmt::Debug for HandleError<S, F>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("HandleError")
.field("inner", &self.inner)
.field("f", &format_args!("{}", std::any::type_name::<F>()))
.finish()
}
}
impl<S, F, B, Res> Service<Request<Body>> for HandleError<S, F>
where
S: Service<Request<Body>, Response = Response<B>> + Clone,
F: FnOnce(S::Error) -> Res + Clone,
Res: IntoResponse<Body>,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = HandleErrorFuture<Oneshot<S, Request<Body>>, F>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
HandleErrorFuture {
f: Some(self.f.clone()),
inner: self.inner.clone().oneshot(req),
}
}
}
#[pin_project]
pub struct HandleErrorFuture<Fut, F> {
#[pin]
inner: Fut,
f: Option<F>,
}
impl<Fut, F, E, B, Res> Future for HandleErrorFuture<Fut, F>
where
Fut: Future<Output = Result<Response<B>, E>>,
F: FnOnce(E) -> Res,
Res: IntoResponse<Body>,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
{
type Output = Result<Response<BoxBody>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match ready!(this.inner.poll(cx)) {
Ok(res) => Ok(res.map(BoxBody::new)).into(),
Err(err) => {
let f = this.f.take().unwrap();
let res = f(err).into_response();
Ok(res.map(BoxBody::new)).into()
}
}
}
}

View file

@ -122,47 +122,6 @@ impl IntoResponse<Body> for std::borrow::Cow<'static, [u8]> {
} }
} }
pub struct Json<T>(pub T);
impl<T> IntoResponse<Body> for Json<T>
where
T: Serialize,
{
fn into_response(self) -> Response<Body> {
let bytes = match serde_json::to_vec(&self.0) {
Ok(res) => res,
Err(err) => {
return Response::builder()
.header(header::CONTENT_TYPE, "text/plain")
.body(Body::from(err.to_string()))
.unwrap();
}
};
let mut res = Response::new(Body::from(bytes));
res.headers_mut().insert(
header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);
res
}
}
pub struct Html<T>(pub T);
impl<T> IntoResponse<Body> for Html<T>
where
T: Into<Bytes>,
{
fn into_response(self) -> Response<Body> {
let bytes = self.0.into();
let mut res = Response::new(Body::from(bytes));
res.headers_mut()
.insert(header::CONTENT_TYPE, HeaderValue::from_static("text/html"));
res
}
}
impl<B> IntoResponse<B> for StatusCode impl<B> IntoResponse<B> for StatusCode
where where
B: Default, B: Default,
@ -195,3 +154,57 @@ where
res res
} }
} }
pub struct Html<T>(pub T);
impl<T> IntoResponse<Body> for Html<T>
where
T: Into<Body>,
{
fn into_response(self) -> Response<Body> {
let mut res = Response::new(self.0.into());
res.headers_mut()
.insert(header::CONTENT_TYPE, HeaderValue::from_static("text/html"));
res
}
}
pub struct Json<T>(pub T);
impl<T> IntoResponse<Body> for Json<T>
where
T: Serialize,
{
fn into_response(self) -> Response<Body> {
let bytes = match serde_json::to_vec(&self.0) {
Ok(res) => res,
Err(err) => {
return Response::builder()
.header(header::CONTENT_TYPE, "text/plain")
.body(Body::from(err.to_string()))
.unwrap();
}
};
let mut res = Response::new(Body::from(bytes));
res.headers_mut().insert(
header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);
res
}
}
pub struct Text<T>(pub T);
impl<T> IntoResponse<Body> for Text<T>
where
T: Into<Body>,
{
fn into_response(self) -> Response<Body> {
let mut res = Response::new(self.0.into());
res.headers_mut()
.insert(header::CONTENT_TYPE, HeaderValue::from_static("text/plain"));
res
}
}

View file

@ -1,12 +1,12 @@
use crate::{ use crate::{
body::BoxBody, body::BoxBody,
handler::{Handler, HandlerSvc}, handler::{self, Handler},
response::IntoResponse, response::IntoResponse,
MethodFilter, ResultExt, ResultExt,
}; };
use bytes::Bytes; use bytes::Bytes;
use futures_util::{future, ready}; use futures_util::{future, ready};
use http::{Request, Response, StatusCode}; use http::{Method, Request, Response, StatusCode};
use hyper::Body; use hyper::Body;
use itertools::Itertools; use itertools::Itertools;
use pin_project::pin_project; use pin_project::pin_project;
@ -27,6 +27,39 @@ use tower::{
// ===== DSL ===== // ===== DSL =====
#[derive(Debug, Copy, Clone)]
pub enum MethodFilter {
Any,
Connect,
Delete,
Get,
Head,
Options,
Patch,
Post,
Put,
Trace,
}
impl MethodFilter {
#[allow(clippy::match_like_matches_macro)]
fn matches(self, method: &Method) -> bool {
match (self, method) {
(MethodFilter::Any, _)
| (MethodFilter::Connect, &Method::CONNECT)
| (MethodFilter::Delete, &Method::DELETE)
| (MethodFilter::Get, &Method::GET)
| (MethodFilter::Head, &Method::HEAD)
| (MethodFilter::Options, &Method::OPTIONS)
| (MethodFilter::Patch, &Method::PATCH)
| (MethodFilter::Post, &Method::POST)
| (MethodFilter::Put, &Method::PUT)
| (MethodFilter::Trace, &Method::TRACE) => true,
_ => false,
}
}
}
#[derive(Clone)] #[derive(Clone)]
pub struct Route<S, F> { pub struct Route<S, F> {
pub(crate) pattern: PathPattern, pub(crate) pattern: PathPattern,
@ -84,21 +117,21 @@ impl<S, F> AddRoute for Route<S, F> {
} }
impl<S, F> OnMethod<S, F> { impl<S, F> OnMethod<S, F> {
pub fn get<H, B, T>(self, handler: H) -> OnMethod<HandlerSvc<H, B, T>, Self> pub fn get<H, B, T>(self, handler: H) -> OnMethod<handler::IntoService<H, B, T>, Self>
where where
H: Handler<B, T>, H: Handler<B, T>,
{ {
self.with_method(MethodFilter::Get, HandlerSvc::new(handler)) self.on_method(MethodFilter::Get, handler.into_service())
} }
pub fn post<H, B, T>(self, handler: H) -> OnMethod<HandlerSvc<H, B, T>, Self> pub fn post<H, B, T>(self, handler: H) -> OnMethod<handler::IntoService<H, B, T>, Self>
where where
H: Handler<B, T>, H: Handler<B, T>,
{ {
self.with_method(MethodFilter::Post, HandlerSvc::new(handler)) self.on_method(MethodFilter::Post, handler.into_service())
} }
pub fn with_method<T>(self, method: MethodFilter, svc: T) -> OnMethod<T, Self> { pub fn on_method<T>(self, method: MethodFilter, svc: T) -> OnMethod<T, Self> {
OnMethod { OnMethod {
method, method,
svc, svc,
@ -551,7 +584,7 @@ mod tests {
fn assert_match(route_spec: &'static str, path: &'static str) { fn assert_match(route_spec: &'static str, path: &'static str) {
let route = PathPattern::new(route_spec); let route = PathPattern::new(route_spec);
assert!( assert!(
route.matches(&path).is_some(), route.matches(path).is_some(),
"`{}` doesn't match `{}`", "`{}` doesn't match `{}`",
path, path,
route_spec route_spec
@ -561,7 +594,7 @@ mod tests {
fn refute_match(route_spec: &'static str, path: &'static str) { fn refute_match(route_spec: &'static str, path: &'static str) {
let route = PathPattern::new(route_spec); let route = PathPattern::new(route_spec);
assert!( assert!(
route.matches(&path).is_none(), route.matches(path).is_none(),
"`{}` did match `{}` (but shouldn't)", "`{}` did match `{}` (but shouldn't)",
path, path,
route_spec route_spec

112
src/service.rs Normal file
View file

@ -0,0 +1,112 @@
use crate::{
body::{Body, BoxBody},
response::IntoResponse,
routing::{EmptyRouter, MethodFilter, OnMethod},
};
use bytes::Bytes;
use futures_util::ready;
use http::{Request, Response};
use pin_project::pin_project;
use std::{
convert::Infallible,
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::{util::Oneshot, BoxError, Service, ServiceExt as _};
pub fn get<S>(svc: S) -> OnMethod<S, EmptyRouter> {
on(MethodFilter::Get, svc)
}
pub fn post<S>(svc: S) -> OnMethod<S, EmptyRouter> {
on(MethodFilter::Post, svc)
}
pub fn on<S>(method: MethodFilter, svc: S) -> OnMethod<S, EmptyRouter> {
OnMethod {
method,
svc,
fallback: EmptyRouter,
}
}
#[derive(Clone)]
pub struct HandleError<S, F> {
inner: S,
f: F,
}
impl<S, F> HandleError<S, F> {
pub(crate) fn new(inner: S, f: F) -> Self {
Self { inner, f }
}
}
impl<S, F> fmt::Debug for HandleError<S, F>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("HandleError")
.field("inner", &self.inner)
.field("f", &format_args!("{}", std::any::type_name::<F>()))
.finish()
}
}
impl<S, F, B, Res> Service<Request<Body>> for HandleError<S, F>
where
S: Service<Request<Body>, Response = Response<B>> + Clone,
F: FnOnce(S::Error) -> Res + Clone,
Res: IntoResponse<Body>,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = HandleErrorFuture<Oneshot<S, Request<Body>>, F>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
HandleErrorFuture {
f: Some(self.f.clone()),
inner: self.inner.clone().oneshot(req),
}
}
}
#[pin_project]
pub struct HandleErrorFuture<Fut, F> {
#[pin]
inner: Fut,
f: Option<F>,
}
impl<Fut, F, E, B, Res> Future for HandleErrorFuture<Fut, F>
where
Fut: Future<Output = Result<Response<B>, E>>,
F: FnOnce(E) -> Res,
Res: IntoResponse<Body>,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
{
type Output = Result<Response<BoxBody>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match ready!(this.inner.poll(cx)) {
Ok(res) => Ok(res.map(BoxBody::new)).into(),
Err(err) => {
let f = this.f.take().unwrap();
let res = f(err).into_response();
Ok(res.map(BoxBody::new)).into()
}
}
}
}

View file

@ -1,4 +1,4 @@
use crate::{extract, get, on_method, post, route, AddRoute, Handler, MethodFilter}; use crate::{extract, get, post, route, routing::MethodFilter, service, AddRoute, Handler};
use http::{Request, Response, StatusCode}; use http::{Request, Response, StatusCode};
use hyper::{Body, Server}; use hyper::{Body, Server};
use serde::Deserialize; use serde::Deserialize;
@ -307,7 +307,7 @@ async fn service_handlers() {
let app = route( let app = route(
"/echo", "/echo",
on_method( service::on(
MethodFilter::Post, MethodFilter::Post,
service_fn(|req: Request<Body>| async move { service_fn(|req: Request<Body>| async move {
Ok::<_, Infallible>(Response::new(req.into_body())) Ok::<_, Infallible>(Response::new(req.into_body()))
@ -316,7 +316,7 @@ async fn service_handlers() {
) )
.route( .route(
"/static/Cargo.toml", "/static/Cargo.toml",
on_method( service::on(
MethodFilter::Get, MethodFilter::Get,
ServeFile::new("Cargo.toml").handle_error(|error: std::io::Error| { ServeFile::new("Cargo.toml").handle_error(|error: std::io::Error| {
(StatusCode::INTERNAL_SERVER_ERROR, error.to_string()) (StatusCode::INTERNAL_SERVER_ERROR, error.to_string())
@ -564,7 +564,7 @@ async fn layer_on_whole_router() {
// // TODO(david): composing two apps that have had layers applied // // TODO(david): composing two apps that have had layers applied
/// Run a `tower::Service` in the background and get a URI for it. /// Run a `tower::Service` in the background and get a URI for it.
pub async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
where where
S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static, S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
ResBody: http_body::Body + Send + 'static, ResBody: http_body::Body + Send + 'static,