Quality of life improvements

This commit is contained in:
David Pedersen 2021-06-01 14:52:18 +02:00
parent 90b9dffce7
commit ea582ab8d9
6 changed files with 42 additions and 125 deletions

View file

@ -17,6 +17,7 @@ serde_json = "1.0"
serde_urlencoded = "0.7" serde_urlencoded = "0.7"
thiserror = "1.0" thiserror = "1.0"
tower = { version = "0.4", features = ["util", "buffer"] } tower = { version = "0.4", features = ["util", "buffer"] }
tower-http = { version = "0.1", features = ["add-extension"] }
[dev-dependencies] [dev-dependencies]
hyper = { version = "0.14", features = ["full"] } hyper = { version = "0.14", features = ["full"] }

View file

@ -51,10 +51,10 @@ async fn get(
params: extract::UrlParams<(String,)>, params: extract::UrlParams<(String,)>,
state: extract::Extension<SharedState>, state: extract::Extension<SharedState>,
) -> Result<Bytes, StatusCode> { ) -> Result<Bytes, StatusCode> {
let state = state.into_inner(); let state = state.0;
let db = &state.lock().unwrap().db; let db = &state.lock().unwrap().db;
let key = params.into_inner(); let (key,) = params.0;
if let Some(value) = db.get(&key) { if let Some(value) = db.get(&key) {
Ok(value.clone()) Ok(value.clone())
@ -69,11 +69,11 @@ async fn set(
value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb
state: extract::Extension<SharedState>, state: extract::Extension<SharedState>,
) { ) {
let state = state.into_inner(); let state = state.0;
let db = &mut state.lock().unwrap().db; let db = &mut state.lock().unwrap().db;
let key = params.into_inner(); let (key,) = params.0;
let value = value.into_inner(); let value = value.0;
db.insert(key, value); db.insert(key, value);
} }

View file

@ -1,10 +1,7 @@
use crate::{ use crate::{body::Body, response::IntoResponse};
body::Body,
response::{BoxIntoResponse, IntoResponse},
};
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use http::{header, Request}; use http::{header, Response, Request};
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::{collections::HashMap, convert::Infallible, str::FromStr}; use std::{collections::HashMap, convert::Infallible, str::FromStr};
@ -80,13 +77,7 @@ define_rejection! {
} }
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct Query<T>(T); pub struct Query<T>(pub T);
impl<T> Query<T> {
pub fn into_inner(self) -> T {
self.0
}
}
#[async_trait] #[async_trait]
impl<T> FromRequest<Body> for Query<T> impl<T> FromRequest<Body> for Query<T>
@ -103,13 +94,7 @@ where
} }
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct Json<T>(T); pub struct Json<T>(pub T);
impl<T> Json<T> {
pub fn into_inner(self) -> T {
self.0
}
}
define_rejection! { define_rejection! {
#[status = BAD_REQUEST] #[status = BAD_REQUEST]
@ -128,24 +113,24 @@ impl<T> FromRequest<Body> for Json<T>
where where
T: DeserializeOwned, T: DeserializeOwned,
{ {
type Rejection = BoxIntoResponse<Body>; type Rejection = Response<Body>;
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
if has_content_type(&req, "application/json") { if has_content_type(&req, "application/json") {
let body = take_body(req).map_err(IntoResponse::boxed)?; let body = take_body(req).map_err(IntoResponse::into_response)?;
let bytes = hyper::body::to_bytes(body) let bytes = hyper::body::to_bytes(body)
.await .await
.map_err(InvalidJsonBody::from_err) .map_err(InvalidJsonBody::from_err)
.map_err(IntoResponse::boxed)?; .map_err(IntoResponse::into_response)?;
let value = serde_json::from_slice(&bytes) let value = serde_json::from_slice(&bytes)
.map_err(InvalidJsonBody::from_err) .map_err(InvalidJsonBody::from_err)
.map_err(IntoResponse::boxed)?; .map_err(IntoResponse::into_response)?;
Ok(Json(value)) Ok(Json(value))
} else { } else {
Err(MissingJsonContentType(()).boxed()) Err(MissingJsonContentType(()).into_response())
} }
} }
} }
@ -173,13 +158,7 @@ define_rejection! {
} }
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct Extension<T>(T); pub struct Extension<T>(pub T);
impl<T> Extension<T> {
pub fn into_inner(self) -> T {
self.0
}
}
#[async_trait] #[async_trait]
impl<T> FromRequest<Body> for Extension<T> impl<T> FromRequest<Body> for Extension<T>
@ -207,15 +186,15 @@ define_rejection! {
#[async_trait] #[async_trait]
impl FromRequest<Body> for Bytes { impl FromRequest<Body> for Bytes {
type Rejection = BoxIntoResponse<Body>; type Rejection = Response<Body>;
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
let body = take_body(req).map_err(IntoResponse::boxed)?; let body = take_body(req).map_err(IntoResponse::into_response)?;
let bytes = hyper::body::to_bytes(body) let bytes = hyper::body::to_bytes(body)
.await .await
.map_err(FailedToBufferBody::from_err) .map_err(FailedToBufferBody::from_err)
.map_err(IntoResponse::boxed)?; .map_err(IntoResponse::into_response)?;
Ok(bytes) Ok(bytes)
} }
@ -229,20 +208,20 @@ define_rejection! {
#[async_trait] #[async_trait]
impl FromRequest<Body> for String { impl FromRequest<Body> for String {
type Rejection = BoxIntoResponse<Body>; type Rejection = Response<Body>;
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
let body = take_body(req).map_err(IntoResponse::boxed)?; let body = take_body(req).map_err(IntoResponse::into_response)?;
let bytes = hyper::body::to_bytes(body) let bytes = hyper::body::to_bytes(body)
.await .await
.map_err(FailedToBufferBody::from_err) .map_err(FailedToBufferBody::from_err)
.map_err(IntoResponse::boxed)? .map_err(IntoResponse::into_response)?
.to_vec(); .to_vec();
let string = String::from_utf8(bytes) let string = String::from_utf8(bytes)
.map_err(InvalidUtf8::from_err) .map_err(InvalidUtf8::from_err)
.map_err(IntoResponse::boxed)?; .map_err(IntoResponse::into_response)?;
Ok(string) Ok(string)
} }
@ -270,36 +249,30 @@ define_rejection! {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct BytesMaxLength<const N: u64>(Bytes); pub struct BytesMaxLength<const N: u64>(pub Bytes);
impl<const N: u64> BytesMaxLength<N> {
pub fn into_inner(self) -> Bytes {
self.0
}
}
#[async_trait] #[async_trait]
impl<const N: u64> FromRequest<Body> for BytesMaxLength<N> { impl<const N: u64> FromRequest<Body> for BytesMaxLength<N> {
type Rejection = BoxIntoResponse<Body>; type Rejection = Response<Body>;
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned(); let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned();
let body = take_body(req).map_err(|reject| reject.boxed())?; let body = take_body(req).map_err(|reject| reject.into_response())?;
let content_length = let content_length =
content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok()); content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok());
if let Some(length) = content_length { if let Some(length) = content_length {
if length > N { if length > N {
return Err(PayloadTooLarge(()).boxed()); return Err(PayloadTooLarge(()).into_response());
} }
} else { } else {
return Err(LengthRequired(()).boxed()); return Err(LengthRequired(()).into_response());
}; };
let bytes = hyper::body::to_bytes(body) let bytes = hyper::body::to_bytes(body)
.await .await
.map_err(|e| FailedToBufferBody::from_err(e).boxed())?; .map_err(|e| FailedToBufferBody::from_err(e).into_response())?;
Ok(BytesMaxLength(bytes)) Ok(BytesMaxLength(bytes))
} }
@ -367,7 +340,7 @@ impl IntoResponse<Body> for InvalidUrlParam {
} }
} }
pub struct UrlParams<T>(T); pub struct UrlParams<T>(pub T);
macro_rules! impl_parse_url { macro_rules! impl_parse_url {
() => {}; () => {};
@ -379,7 +352,7 @@ macro_rules! impl_parse_url {
$head: FromStr + Send, $head: FromStr + Send,
$( $tail: FromStr + Send, )* $( $tail: FromStr + Send, )*
{ {
type Rejection = BoxIntoResponse<Body>; type Rejection = Response<Body>;
#[allow(non_snake_case)] #[allow(non_snake_case)]
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
@ -389,27 +362,27 @@ macro_rules! impl_parse_url {
{ {
params.take().expect("params already taken").0 params.take().expect("params already taken").0
} else { } else {
return Err(MissingRouteParams(()).boxed()) return Err(MissingRouteParams(()).into_response())
}; };
if let [(_, $head), $((_, $tail),)*] = &*params { if let [(_, $head), $((_, $tail),)*] = &*params {
let $head = if let Ok(x) = $head.parse::<$head>() { let $head = if let Ok(x) = $head.parse::<$head>() {
x x
} else { } else {
return Err(InvalidUrlParam::new::<$head>().boxed()); return Err(InvalidUrlParam::new::<$head>().into_response());
}; };
$( $(
let $tail = if let Ok(x) = $tail.parse::<$tail>() { let $tail = if let Ok(x) = $tail.parse::<$tail>() {
x x
} else { } else {
return Err(InvalidUrlParam::new::<$tail>().boxed()); return Err(InvalidUrlParam::new::<$tail>().into_response());
}; };
)* )*
Ok(UrlParams(($head, $($tail,)*))) Ok(UrlParams(($head, $($tail,)*)))
} else { } else {
return Err(MissingRouteParams(()).boxed()) return Err(MissingRouteParams(()).into_response())
} }
} }
} }
@ -420,49 +393,6 @@ macro_rules! impl_parse_url {
impl_parse_url!(T1, T2, T3, T4, T5, T6); impl_parse_url!(T1, T2, T3, T4, T5, T6);
impl<T1> UrlParams<(T1,)> {
pub fn into_inner(self) -> T1 {
(self.0).0
}
}
impl<T1, T2> UrlParams<(T1, T2)> {
pub fn into_inner(self) -> (T1, T2) {
((self.0).0, (self.0).1)
}
}
impl<T1, T2, T3> UrlParams<(T1, T2, T3)> {
pub fn into_inner(self) -> (T1, T2, T3) {
((self.0).0, (self.0).1, (self.0).2)
}
}
impl<T1, T2, T3, T4> UrlParams<(T1, T2, T3, T4)> {
pub fn into_inner(self) -> (T1, T2, T3, T4) {
((self.0).0, (self.0).1, (self.0).2, (self.0).3)
}
}
impl<T1, T2, T3, T4, T5> UrlParams<(T1, T2, T3, T4, T5)> {
pub fn into_inner(self) -> (T1, T2, T3, T4, T5) {
((self.0).0, (self.0).1, (self.0).2, (self.0).3, (self.0).4)
}
}
impl<T1, T2, T3, T4, T5, T6> UrlParams<(T1, T2, T3, T4, T5, T6)> {
pub fn into_inner(self) -> (T1, T2, T3, T4, T5, T6) {
(
(self.0).0,
(self.0).1,
(self.0).2,
(self.0).3,
(self.0).4,
(self.0).5,
)
}
}
define_rejection! { define_rejection! {
#[status = INTERNAL_SERVER_ERROR] #[status = INTERNAL_SERVER_ERROR]
#[body = "Cannot have two request body extractors for a single handler"] #[body = "Cannot have two request body extractors for a single handler"]

View file

@ -22,6 +22,9 @@ pub mod handler;
pub mod response; pub mod response;
pub mod routing; pub mod routing;
pub use tower_http::add_extension::{AddExtension, AddExtensionLayer};
pub use async_trait::async_trait;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;

View file

@ -5,17 +5,9 @@ use serde::Serialize;
use std::convert::Infallible; use std::convert::Infallible;
use tower::util::Either; use tower::util::Either;
// TODO(david): can we change this to not be generic over the body and just use hyper::Body?
pub trait IntoResponse<B> { pub trait IntoResponse<B> {
fn into_response(self) -> Response<B>; fn into_response(self) -> Response<B>;
// TODO(david): remove this an return return `Response<B>` instead. That is what this method
// does anyway.
fn boxed(self) -> BoxIntoResponse<B>
where
Self: Sized + 'static,
{
BoxIntoResponse(self.into_response())
}
} }
impl<B> IntoResponse<B> for () impl<B> IntoResponse<B> for ()
@ -171,14 +163,6 @@ where
} }
} }
pub struct BoxIntoResponse<B>(Response<B>);
impl<B> IntoResponse<B> for BoxIntoResponse<B> {
fn into_response(self) -> Response<B> {
self.0
}
}
impl<B> IntoResponse<B> for StatusCode impl<B> IntoResponse<B> for StatusCode
where where
B: Default, B: Default,

View file

@ -51,7 +51,7 @@ async fn deserialize_body() {
let app = app() let app = app()
.at("/") .at("/")
.post(|_: Request<Body>, input: extract::Json<Input>| async { input.into_inner().foo }) .post(|_: Request<Body>, input: extract::Json<Input>| async { input.0.foo })
.into_service(); .into_service();
let addr = run_in_background(app).await; let addr = run_in_background(app).await;
@ -78,8 +78,7 @@ async fn consume_body_to_json_requires_json_content_type() {
let app = app() let app = app()
.at("/") .at("/")
.post(|_: Request<Body>, input: extract::Json<Input>| async { .post(|_: Request<Body>, input: extract::Json<Input>| async {
let input = input.into_inner(); input.0.foo
input.foo
}) })
.into_service(); .into_service();
@ -216,7 +215,7 @@ async fn extracting_url_params() {
.at("/users/:id") .at("/users/:id")
.get( .get(
|_: Request<Body>, params: extract::UrlParams<(i32,)>| async move { |_: Request<Body>, params: extract::UrlParams<(i32,)>| async move {
let id = params.into_inner(); let (id,) = params.0;
assert_eq!(id, 42); assert_eq!(id, 42);
}, },
) )