mirror of
https://github.com/tokio-rs/axum.git
synced 2024-12-28 07:20:12 +01:00
This changes everything
This commit is contained in:
parent
19cbece1dc
commit
18f613ff98
5 changed files with 424 additions and 146 deletions
|
@ -1,7 +1,7 @@
|
|||
#![allow(warnings)]
|
||||
|
||||
use bytes::Bytes;
|
||||
use http::{Request, StatusCode};
|
||||
use http::{Request, Response, StatusCode};
|
||||
use hyper::Server;
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
|
@ -14,7 +14,12 @@ use tower::{make::Shared, ServiceBuilder};
|
|||
use tower_http::{
|
||||
add_extension::AddExtensionLayer, compression::CompressionLayer, trace::TraceLayer,
|
||||
};
|
||||
use tower_web::{body::Body, extract, response, Error};
|
||||
use tower_web::{
|
||||
body::Body,
|
||||
extract,
|
||||
response::{self, IntoResponse},
|
||||
Error,
|
||||
};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
|
@ -54,16 +59,16 @@ async fn get(
|
|||
_req: Request<Body>,
|
||||
params: extract::UrlParams<(String,)>,
|
||||
state: extract::Extension<SharedState>,
|
||||
) -> Result<Bytes, Error> {
|
||||
) -> Result<Bytes, NotFound> {
|
||||
let state = state.into_inner();
|
||||
let db = &state.lock().unwrap().db;
|
||||
|
||||
let (key,) = params.into_inner();
|
||||
let key = params.into_inner();
|
||||
|
||||
if let Some(value) = db.get(&key) {
|
||||
Ok(value.clone())
|
||||
} else {
|
||||
Err(Error::Status(StatusCode::NOT_FOUND))
|
||||
Err(NotFound)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -72,14 +77,23 @@ async fn set(
|
|||
params: extract::UrlParams<(String,)>,
|
||||
value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb
|
||||
state: extract::Extension<SharedState>,
|
||||
) -> response::Empty {
|
||||
) {
|
||||
let state = state.into_inner();
|
||||
let db = &mut state.lock().unwrap().db;
|
||||
|
||||
let (key,) = params.into_inner();
|
||||
let key = params.into_inner();
|
||||
let value = value.into_inner();
|
||||
|
||||
db.insert(key.to_string(), value);
|
||||
|
||||
response::Empty
|
||||
}
|
||||
|
||||
struct NotFound;
|
||||
|
||||
impl IntoResponse<Body> for NotFound {
|
||||
fn into_response(self) -> Response<Body> {
|
||||
Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Body::empty())
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
|
335
src/extract.rs
335
src/extract.rs
|
@ -1,36 +1,85 @@
|
|||
use crate::{body::Body, Error};
|
||||
use crate::{
|
||||
body::Body,
|
||||
response::{BoxIntoResponse, IntoResponse},
|
||||
Error,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use http::{header, Request, StatusCode};
|
||||
use http::{header, Request};
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::{collections::HashMap, str::FromStr};
|
||||
use std::{collections::HashMap, convert::Infallible, str::FromStr};
|
||||
|
||||
#[async_trait]
|
||||
pub trait FromRequest: Sized {
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Error>;
|
||||
}
|
||||
pub trait FromRequest<B>: Sized {
|
||||
type Rejection: IntoResponse<B>;
|
||||
|
||||
fn take_body(req: &mut Request<Body>) -> Body {
|
||||
struct BodyAlreadyTaken;
|
||||
|
||||
if req.extensions_mut().insert(BodyAlreadyTaken).is_some() {
|
||||
panic!("Cannot have two request body on extractors")
|
||||
} else {
|
||||
let body = std::mem::take(req.body_mut());
|
||||
body
|
||||
}
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T> FromRequest for Option<T>
|
||||
impl<T, B> FromRequest<B> for Option<T>
|
||||
where
|
||||
T: FromRequest,
|
||||
T: FromRequest<B>,
|
||||
{
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Option<T>, Error> {
|
||||
type Rejection = Infallible;
|
||||
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Option<T>, Self::Rejection> {
|
||||
Ok(T::from_request(req).await.ok())
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! define_rejection {
|
||||
(
|
||||
#[status = $status:ident]
|
||||
#[body = $body:expr]
|
||||
pub struct $name:ident (());
|
||||
) => {
|
||||
#[derive(Debug)]
|
||||
pub struct $name(());
|
||||
|
||||
impl IntoResponse<Body> for $name {
|
||||
fn into_response(self) -> http::Response<Body> {
|
||||
let mut res = http::Response::new(Body::from($body));
|
||||
*res.status_mut() = http::StatusCode::$status;
|
||||
res
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
(
|
||||
#[status = $status:ident]
|
||||
#[body = $body:expr]
|
||||
pub struct $name:ident (BoxError);
|
||||
) => {
|
||||
#[derive(Debug)]
|
||||
pub struct $name(tower::BoxError);
|
||||
|
||||
impl $name {
|
||||
fn from_err<E>(err: E) -> Self
|
||||
where
|
||||
E: Into<tower::BoxError>,
|
||||
{
|
||||
Self(err.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for $name {
|
||||
fn into_response(self) -> http::Response<Body> {
|
||||
let mut res =
|
||||
http::Response::new(Body::from(format!(concat!($body, ": {}"), self.0)));
|
||||
*res.status_mut() = http::StatusCode::$status;
|
||||
res
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "Query string was invalid or missing"]
|
||||
pub struct QueryStringMissing(());
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Query<T>(T);
|
||||
|
||||
|
@ -41,13 +90,15 @@ impl<T> Query<T> {
|
|||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T> FromRequest for Query<T>
|
||||
impl<T> FromRequest<Body> for Query<T>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> {
|
||||
let query = req.uri().query().ok_or(Error::QueryStringMissing)?;
|
||||
let value = serde_urlencoded::from_str(query).map_err(Error::DeserializeQueryString)?;
|
||||
type Rejection = QueryStringMissing;
|
||||
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
let query = req.uri().query().ok_or(QueryStringMissing(()))?;
|
||||
let value = serde_urlencoded::from_str(query).map_err(|_| QueryStringMissing(()))?;
|
||||
Ok(Query(value))
|
||||
}
|
||||
}
|
||||
|
@ -61,22 +112,41 @@ impl<T> Json<T> {
|
|||
}
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "Failed to parse the response body as JSON"]
|
||||
pub struct InvalidJsonBody(BoxError);
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "Expected request with `Content-Type: application/json`"]
|
||||
pub struct MissingJsonContentType(());
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T> FromRequest for Json<T>
|
||||
impl<T> FromRequest<Body> for Json<T>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> {
|
||||
type Rejection = BoxIntoResponse<Body>;
|
||||
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
if has_content_type(&req, "application/json") {
|
||||
let body = take_body(req);
|
||||
let body = take_body(req).map_err(IntoResponse::boxed)?;
|
||||
|
||||
let bytes = hyper::body::to_bytes(body)
|
||||
.await
|
||||
.map_err(Error::ConsumeRequestBody)?;
|
||||
let value = serde_json::from_slice(&bytes).map_err(Error::DeserializeRequestBody)?;
|
||||
.map_err(InvalidJsonBody::from_err)
|
||||
.map_err(IntoResponse::boxed)?;
|
||||
|
||||
let value = serde_json::from_slice(&bytes)
|
||||
.map_err(InvalidJsonBody::from_err)
|
||||
.map_err(IntoResponse::boxed)?;
|
||||
|
||||
Ok(Json(value))
|
||||
} else {
|
||||
Err(Error::Status(StatusCode::BAD_REQUEST))
|
||||
Err(MissingJsonContentType(()).boxed())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -97,6 +167,12 @@ fn has_content_type<B>(req: &Request<B>, expected_content_type: &str) -> bool {
|
|||
content_type.starts_with(expected_content_type)
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = INTERNAL_SERVER_ERROR]
|
||||
#[body = "Missing request extension"]
|
||||
pub struct MissingExtension(());
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Extension<T>(T);
|
||||
|
||||
|
@ -107,60 +183,93 @@ impl<T> Extension<T> {
|
|||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T> FromRequest for Extension<T>
|
||||
impl<T> FromRequest<Body> for Extension<T>
|
||||
where
|
||||
T: Clone + Send + Sync + 'static,
|
||||
{
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> {
|
||||
type Rejection = MissingExtension;
|
||||
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
let value = req
|
||||
.extensions()
|
||||
.get::<T>()
|
||||
.ok_or_else(|| Error::MissingExtension {
|
||||
type_name: std::any::type_name::<T>(),
|
||||
})
|
||||
.ok_or(MissingExtension(()))
|
||||
.map(|x| x.clone())?;
|
||||
|
||||
Ok(Extension(value))
|
||||
}
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "Failed to buffer the request body"]
|
||||
pub struct FailedToBufferBody(BoxError);
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FromRequest for Bytes {
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> {
|
||||
let body = take_body(req);
|
||||
impl FromRequest<Body> for Bytes {
|
||||
type Rejection = BoxIntoResponse<Body>;
|
||||
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
let body = take_body(req).map_err(IntoResponse::boxed)?;
|
||||
|
||||
let bytes = hyper::body::to_bytes(body)
|
||||
.await
|
||||
.map_err(Error::ConsumeRequestBody)?;
|
||||
.map_err(FailedToBufferBody::from_err)
|
||||
.map_err(IntoResponse::boxed)?;
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "Response body didn't contain valid UTF-8"]
|
||||
pub struct InvalidUtf8(BoxError);
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FromRequest for String {
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> {
|
||||
let body = take_body(req);
|
||||
impl FromRequest<Body> for String {
|
||||
type Rejection = BoxIntoResponse<Body>;
|
||||
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
let body = take_body(req).map_err(IntoResponse::boxed)?;
|
||||
|
||||
let bytes = hyper::body::to_bytes(body)
|
||||
.await
|
||||
.map_err(Error::ConsumeRequestBody)?
|
||||
.map_err(FailedToBufferBody::from_err)
|
||||
.map_err(IntoResponse::boxed)?
|
||||
.to_vec();
|
||||
|
||||
let string = String::from_utf8(bytes).map_err(|_| Error::InvalidUtf8)?;
|
||||
let string = String::from_utf8(bytes)
|
||||
.map_err(InvalidUtf8::from_err)
|
||||
.map_err(IntoResponse::boxed)?;
|
||||
|
||||
Ok(string)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FromRequest for Body {
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> {
|
||||
let body = take_body(req);
|
||||
Ok(body)
|
||||
impl FromRequest<Body> for Body {
|
||||
type Rejection = BodyAlreadyTaken;
|
||||
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
take_body(req)
|
||||
}
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = PAYLOAD_TOO_LARGE]
|
||||
#[body = "Request payload is too large"]
|
||||
pub struct PayloadTooLarge(());
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = LENGTH_REQUIRED]
|
||||
#[body = "Content length header is required"]
|
||||
pub struct LengthRequired(());
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BytesMaxLength<const N: u64>(Bytes);
|
||||
|
||||
|
@ -171,30 +280,38 @@ impl<const N: u64> BytesMaxLength<N> {
|
|||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<const N: u64> FromRequest for BytesMaxLength<N> {
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> {
|
||||
impl<const N: u64> FromRequest<Body> for BytesMaxLength<N> {
|
||||
type Rejection = BoxIntoResponse<Body>;
|
||||
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned();
|
||||
let body = take_body(req);
|
||||
let body = take_body(req).map_err(|reject| reject.boxed())?;
|
||||
|
||||
let content_length =
|
||||
content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok());
|
||||
|
||||
if let Some(length) = content_length {
|
||||
if length > N {
|
||||
return Err(Error::PayloadTooLarge);
|
||||
return Err(PayloadTooLarge(()).boxed());
|
||||
}
|
||||
} else {
|
||||
return Err(Error::LengthRequired);
|
||||
return Err(LengthRequired(()).boxed());
|
||||
};
|
||||
|
||||
let bytes = hyper::body::to_bytes(body)
|
||||
.await
|
||||
.map_err(Error::ConsumeRequestBody)?;
|
||||
.map_err(|e| FailedToBufferBody::from_err(e).boxed())?;
|
||||
|
||||
Ok(BytesMaxLength(bytes))
|
||||
}
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = INTERNAL_SERVER_ERROR]
|
||||
#[body = "No url params found for matched route. This is a bug in tower-web. Please open an issue"]
|
||||
pub struct MissingRouteParams(());
|
||||
}
|
||||
|
||||
pub struct UrlParamsMap(HashMap<String, String>);
|
||||
|
||||
impl UrlParamsMap {
|
||||
|
@ -217,8 +334,10 @@ impl UrlParamsMap {
|
|||
}
|
||||
|
||||
#[async_trait]
|
||||
impl FromRequest for UrlParamsMap {
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> {
|
||||
impl FromRequest<Body> for UrlParamsMap {
|
||||
type Rejection = MissingRouteParams;
|
||||
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
if let Some(params) = req
|
||||
.extensions_mut()
|
||||
.get_mut::<Option<crate::routing::UrlParams>>()
|
||||
|
@ -226,62 +345,78 @@ impl FromRequest for UrlParamsMap {
|
|||
let params = params.take().expect("params already taken").0;
|
||||
Ok(Self(params.into_iter().collect()))
|
||||
} else {
|
||||
panic!("no url params found for matched route. This is a bug in tower-web")
|
||||
Err(MissingRouteParams(()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UrlParams<T>(T);
|
||||
#[derive(Debug)]
|
||||
pub struct InvalidUrlParam {
|
||||
type_name: &'static str,
|
||||
}
|
||||
|
||||
impl<T> UrlParams<T> {
|
||||
pub fn into_inner(self) -> T {
|
||||
self.0
|
||||
impl InvalidUrlParam {
|
||||
fn new<T>() -> Self {
|
||||
InvalidUrlParam {
|
||||
type_name: std::any::type_name::<T>(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for InvalidUrlParam {
|
||||
fn into_response(self) -> http::Response<Body> {
|
||||
let mut res = http::Response::new(Body::from(format!(
|
||||
"Invalid URL param. Expected something of type `{}`",
|
||||
self.type_name
|
||||
)));
|
||||
*res.status_mut() = http::StatusCode::BAD_REQUEST;
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UrlParams<T>(T);
|
||||
|
||||
macro_rules! impl_parse_url {
|
||||
() => {};
|
||||
|
||||
( $head:ident, $($tail:ident),* $(,)? ) => {
|
||||
#[async_trait]
|
||||
impl<$head, $($tail,)*> FromRequest for UrlParams<($head, $($tail,)*)>
|
||||
impl<$head, $($tail,)*> FromRequest<Body> for UrlParams<($head, $($tail,)*)>
|
||||
where
|
||||
$head: FromStr + Send,
|
||||
$( $tail: FromStr + Send, )*
|
||||
{
|
||||
type Rejection = BoxIntoResponse<Body>;
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> {
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
let params = if let Some(params) = req
|
||||
.extensions_mut()
|
||||
.get_mut::<Option<crate::routing::UrlParams>>()
|
||||
{
|
||||
params.take().expect("params already taken").0
|
||||
} else {
|
||||
panic!("no url params found for matched route. This is a bug in tower-web")
|
||||
return Err(MissingRouteParams(()).boxed())
|
||||
};
|
||||
|
||||
if let [(_, $head), $((_, $tail),)*] = &*params {
|
||||
let $head = if let Ok(x) = $head.parse::<$head>() {
|
||||
x
|
||||
} else {
|
||||
return Err(Error::InvalidUrlParam {
|
||||
type_name: std::any::type_name::<$head>(),
|
||||
});
|
||||
return Err(InvalidUrlParam::new::<$head>().boxed());
|
||||
};
|
||||
|
||||
$(
|
||||
let $tail = if let Ok(x) = $tail.parse::<$tail>() {
|
||||
x
|
||||
} else {
|
||||
return Err(Error::InvalidUrlParam {
|
||||
type_name: std::any::type_name::<$tail>(),
|
||||
});
|
||||
return Err(InvalidUrlParam::new::<$tail>().boxed());
|
||||
};
|
||||
)*
|
||||
|
||||
Ok(UrlParams(($head, $($tail,)*)))
|
||||
} else {
|
||||
panic!("wrong number of url params found for matched route. This is a bug in tower-web")
|
||||
return Err(MissingRouteParams(()).boxed())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -290,4 +425,64 @@ macro_rules! impl_parse_url {
|
|||
};
|
||||
}
|
||||
|
||||
impl_parse_url!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
|
||||
impl_parse_url!(T1, T2, T3, T4, T5, T6);
|
||||
|
||||
impl<T1> UrlParams<(T1,)> {
|
||||
pub fn into_inner(self) -> T1 {
|
||||
(self.0).0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T1, T2> UrlParams<(T1, T2)> {
|
||||
pub fn into_inner(self) -> (T1, T2) {
|
||||
((self.0).0, (self.0).1)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T1, T2, T3> UrlParams<(T1, T2, T3)> {
|
||||
pub fn into_inner(self) -> (T1, T2, T3) {
|
||||
((self.0).0, (self.0).1, (self.0).2)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T1, T2, T3, T4> UrlParams<(T1, T2, T3, T4)> {
|
||||
pub fn into_inner(self) -> (T1, T2, T3, T4) {
|
||||
((self.0).0, (self.0).1, (self.0).2, (self.0).3)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T1, T2, T3, T4, T5> UrlParams<(T1, T2, T3, T4, T5)> {
|
||||
pub fn into_inner(self) -> (T1, T2, T3, T4, T5) {
|
||||
((self.0).0, (self.0).1, (self.0).2, (self.0).3, (self.0).4)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T1, T2, T3, T4, T5, T6> UrlParams<(T1, T2, T3, T4, T5, T6)> {
|
||||
pub fn into_inner(self) -> (T1, T2, T3, T4, T5, T6) {
|
||||
(
|
||||
(self.0).0,
|
||||
(self.0).1,
|
||||
(self.0).2,
|
||||
(self.0).3,
|
||||
(self.0).4,
|
||||
(self.0).5,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = INTERNAL_SERVER_ERROR]
|
||||
#[body = "Cannot have two request body extractors for a single handler"]
|
||||
pub struct BodyAlreadyTaken(());
|
||||
}
|
||||
|
||||
fn take_body(req: &mut Request<Body>) -> Result<Body, BodyAlreadyTaken> {
|
||||
struct BodyAlreadyTakenExt;
|
||||
|
||||
if req.extensions_mut().insert(BodyAlreadyTakenExt).is_some() {
|
||||
Err(BodyAlreadyTaken(()))
|
||||
} else {
|
||||
let body = std::mem::take(req.body_mut());
|
||||
Ok(body)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ use async_trait::async_trait;
|
|||
use futures_util::future;
|
||||
use http::{Request, Response};
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
future::Future,
|
||||
marker::PhantomData,
|
||||
task::{Context, Poll},
|
||||
|
@ -24,7 +25,7 @@ pub trait Handler<B, In>: Sized {
|
|||
#[doc(hidden)]
|
||||
type Sealed: sealed::HiddentTrait;
|
||||
|
||||
async fn call(self, req: Request<Body>) -> Result<Response<B>, Error>;
|
||||
async fn call(self, req: Request<Body>) -> Response<B>;
|
||||
|
||||
fn layer<L>(self, layer: L) -> Layered<L::Service, In>
|
||||
where
|
||||
|
@ -45,7 +46,7 @@ where
|
|||
|
||||
type Sealed = sealed::Hidden;
|
||||
|
||||
async fn call(self, req: Request<Body>) -> Result<Response<B>, Error> {
|
||||
async fn call(self, req: Request<Body>) -> Response<B> {
|
||||
self(req).await.into_response()
|
||||
}
|
||||
}
|
||||
|
@ -61,19 +62,28 @@ macro_rules! impl_handler {
|
|||
F: Fn(Request<Body>, $head, $($tail,)*) -> Fut + Send + Sync,
|
||||
Fut: Future<Output = Res> + Send,
|
||||
Res: IntoResponse<B>,
|
||||
$head: FromRequest + Send,
|
||||
$( $tail: FromRequest + Send, )*
|
||||
$head: FromRequest<B> + Send,
|
||||
$( $tail: FromRequest<B> + Send, )*
|
||||
{
|
||||
type Response = Res;
|
||||
|
||||
type Sealed = sealed::Hidden;
|
||||
|
||||
async fn call(self, mut req: Request<Body>) -> Result<Response<B>, Error> {
|
||||
let $head = $head::from_request(&mut req).await?;
|
||||
async fn call(self, mut req: Request<Body>) -> Response<B> {
|
||||
let $head = match $head::from_request(&mut req).await {
|
||||
Ok(value) => value,
|
||||
Err(rejection) => return rejection.into_response(),
|
||||
};
|
||||
|
||||
$(
|
||||
let $tail = $tail::from_request(&mut req).await?;
|
||||
let $tail = match $tail::from_request(&mut req).await {
|
||||
Ok(value) => value,
|
||||
Err(rejection) => return rejection.into_response(),
|
||||
};
|
||||
)*
|
||||
|
||||
let res = self(req, $head, $($tail,)*).await;
|
||||
|
||||
res.into_response()
|
||||
}
|
||||
}
|
||||
|
@ -102,18 +112,25 @@ where
|
|||
impl<S, B, T> Handler<B, T> for Layered<S, T>
|
||||
where
|
||||
S: Service<Request<Body>, Response = Response<B>> + Send,
|
||||
S::Error: Into<BoxError>,
|
||||
S::Error: IntoResponse<B>,
|
||||
S::Response: IntoResponse<B>,
|
||||
S::Future: Send,
|
||||
{
|
||||
type Response = S::Response;
|
||||
|
||||
type Sealed = sealed::Hidden;
|
||||
|
||||
async fn call(self, req: Request<Body>) -> Result<Self::Response, Error> {
|
||||
self.svc
|
||||
async fn call(self, req: Request<Body>) -> Self::Response {
|
||||
// TODO(david): add tests for nesting services
|
||||
match self
|
||||
.svc
|
||||
.oneshot(req)
|
||||
.await
|
||||
.map_err(|err| Error::Dynamic(err.into()))
|
||||
.map_err(IntoResponse::into_response)
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(res) => res,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -158,7 +175,7 @@ where
|
|||
H::Response: 'static,
|
||||
{
|
||||
type Response = Response<B>;
|
||||
type Error = Error;
|
||||
type Error = Infallible;
|
||||
type Future = future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
|
@ -170,9 +187,6 @@ where
|
|||
|
||||
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
||||
let handler = self.handler.clone();
|
||||
Box::pin(async move {
|
||||
let res = Handler::call(handler, req).await?.into_response()?;
|
||||
Ok(res)
|
||||
})
|
||||
Box::pin(async move { Ok(Handler::call(handler, req).await) })
|
||||
}
|
||||
}
|
||||
|
|
122
src/response.rs
122
src/response.rs
|
@ -1,86 +1,127 @@
|
|||
use crate::{Body, Error};
|
||||
use crate::Body;
|
||||
use bytes::Bytes;
|
||||
use http::{header, HeaderValue, Response};
|
||||
use http::{header, HeaderValue, Response, StatusCode};
|
||||
use serde::Serialize;
|
||||
use std::convert::Infallible;
|
||||
use tower::{util::Either, BoxError};
|
||||
|
||||
pub trait IntoResponse<B> {
|
||||
fn into_response(self) -> Result<Response<B>, Error>;
|
||||
fn into_response(self) -> Response<B>;
|
||||
|
||||
fn boxed(self) -> BoxIntoResponse<B>
|
||||
where
|
||||
Self: Sized + 'static,
|
||||
{
|
||||
BoxIntoResponse(self.into_response())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, T> IntoResponse<B> for Result<T, Error>
|
||||
impl<B> IntoResponse<B> for ()
|
||||
where
|
||||
B: Default,
|
||||
{
|
||||
fn into_response(self) -> Response<B> {
|
||||
Response::new(B::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> IntoResponse<B> for Infallible {
|
||||
fn into_response(self) -> Response<B> {
|
||||
match self {}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, K, B> IntoResponse<B> for Either<T, K>
|
||||
where
|
||||
T: IntoResponse<B>,
|
||||
K: IntoResponse<B>,
|
||||
{
|
||||
fn into_response(self) -> Result<Response<B>, Error> {
|
||||
self.and_then(IntoResponse::into_response)
|
||||
fn into_response(self) -> Response<B> {
|
||||
match self {
|
||||
Either::A(inner) => inner.into_response(),
|
||||
Either::B(inner) => inner.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, T, E> IntoResponse<B> for Result<T, E>
|
||||
where
|
||||
T: IntoResponse<B>,
|
||||
E: IntoResponse<B>,
|
||||
{
|
||||
fn into_response(self) -> Response<B> {
|
||||
match self {
|
||||
Ok(value) => value.into_response(),
|
||||
Err(err) => err.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> IntoResponse<B> for Response<B> {
|
||||
fn into_response(self) -> Result<Response<B>, Error> {
|
||||
Ok(self)
|
||||
fn into_response(self) -> Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for &'static str {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::from(self)))
|
||||
fn into_response(self) -> Response<Body> {
|
||||
Response::new(Body::from(self))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for String {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::from(self)))
|
||||
fn into_response(self) -> Response<Body> {
|
||||
Response::new(Body::from(self))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for std::borrow::Cow<'static, str> {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::from(self)))
|
||||
fn into_response(self) -> Response<Body> {
|
||||
Response::new(Body::from(self))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for Bytes {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
fn into_response(self) -> Response<Body> {
|
||||
let mut res = Response::new(Body::from(self));
|
||||
res.headers_mut().insert(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static("application/octet-stream"),
|
||||
);
|
||||
Ok(res)
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for &'static [u8] {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
fn into_response(self) -> Response<Body> {
|
||||
let mut res = Response::new(Body::from(self));
|
||||
res.headers_mut().insert(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static("application/octet-stream"),
|
||||
);
|
||||
Ok(res)
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for Vec<u8> {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
fn into_response(self) -> Response<Body> {
|
||||
let mut res = Response::new(Body::from(self));
|
||||
res.headers_mut().insert(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static("application/octet-stream"),
|
||||
);
|
||||
Ok(res)
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for std::borrow::Cow<'static, [u8]> {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
fn into_response(self) -> Response<Body> {
|
||||
let mut res = Response::new(Body::from(self));
|
||||
res.headers_mut().insert(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static("application/octet-stream"),
|
||||
);
|
||||
Ok(res)
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -90,14 +131,23 @@ impl<T> IntoResponse<Body> for Json<T>
|
|||
where
|
||||
T: Serialize,
|
||||
{
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
let bytes = serde_json::to_vec(&self.0).map_err(Error::SerializeResponseBody)?;
|
||||
fn into_response(self) -> Response<Body> {
|
||||
let bytes = match serde_json::to_vec(&self.0) {
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
return Response::builder()
|
||||
.header(header::CONTENT_TYPE, "text/plain")
|
||||
.body(Body::from(err.to_string()))
|
||||
.unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
let mut res = Response::new(Body::from(bytes));
|
||||
res.headers_mut().insert(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static("application/json"),
|
||||
);
|
||||
Ok(res)
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -107,20 +157,28 @@ impl<T> IntoResponse<Body> for Html<T>
|
|||
where
|
||||
T: Into<Bytes>,
|
||||
{
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
fn into_response(self) -> Response<Body> {
|
||||
let bytes = self.0.into();
|
||||
let mut res = Response::new(Body::from(bytes));
|
||||
res.headers_mut()
|
||||
.insert(header::CONTENT_TYPE, HeaderValue::from_static("text/html"));
|
||||
Ok(res)
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct Empty;
|
||||
pub struct BoxIntoResponse<B>(Response<B>);
|
||||
|
||||
impl IntoResponse<Body> for Empty {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::empty()))
|
||||
impl<B> IntoResponse<B> for BoxIntoResponse<B> {
|
||||
fn into_response(self) -> Response<B> {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for BoxError {
|
||||
fn into_response(self) -> Response<Body> {
|
||||
Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::from(self.to_string()))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
|
35
src/tests.rs
35
src/tests.rs
|
@ -1,4 +1,4 @@
|
|||
use crate::{app, extract, response};
|
||||
use crate::{app, extract};
|
||||
use http::{Request, Response, StatusCode};
|
||||
use hyper::{Body, Server};
|
||||
use serde::Deserialize;
|
||||
|
@ -10,7 +10,7 @@ use tower::{make::Shared, BoxError, Service};
|
|||
async fn hello_world() {
|
||||
let app = app()
|
||||
.at("/")
|
||||
.get(|_: Request<Body>| async { Ok("Hello, World!") })
|
||||
.get(|_: Request<Body>| async { "Hello, World!" })
|
||||
.into_service();
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
@ -25,7 +25,7 @@ async fn hello_world() {
|
|||
async fn consume_body() {
|
||||
let app = app()
|
||||
.at("/")
|
||||
.get(|_: Request<Body>, body: String| async { Ok(body) })
|
||||
.get(|_: Request<Body>, body: String| async { body })
|
||||
.into_service();
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
@ -51,7 +51,7 @@ async fn deserialize_body() {
|
|||
|
||||
let app = app()
|
||||
.at("/")
|
||||
.post(|_: Request<Body>, input: extract::Json<Input>| async { Ok(input.into_inner().foo) })
|
||||
.post(|_: Request<Body>, input: extract::Json<Input>| async { input.into_inner().foo })
|
||||
.into_service();
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
@ -79,7 +79,7 @@ async fn consume_body_to_json_requires_json_content_type() {
|
|||
.at("/")
|
||||
.post(|_: Request<Body>, input: extract::Json<Input>| async {
|
||||
let input = input.into_inner();
|
||||
Ok(input.foo)
|
||||
input.foo
|
||||
})
|
||||
.into_service();
|
||||
|
||||
|
@ -93,8 +93,10 @@ async fn consume_body_to_json_requires_json_content_type() {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
// TODO(david): is this the most appropriate response code?
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
let status = res.status();
|
||||
dbg!(res.text().await.unwrap());
|
||||
|
||||
assert_eq!(status, StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -113,7 +115,6 @@ async fn body_with_length_limit() {
|
|||
.post(
|
||||
|req: Request<Body>, _body: extract::BytesMaxLength<LIMIT>| async move {
|
||||
dbg!(&req);
|
||||
Ok(response::Empty)
|
||||
},
|
||||
)
|
||||
.into_service();
|
||||
|
@ -161,12 +162,12 @@ async fn body_with_length_limit() {
|
|||
async fn routing() {
|
||||
let app = app()
|
||||
.at("/users")
|
||||
.get(|_: Request<Body>| async { Ok("users#index") })
|
||||
.post(|_: Request<Body>| async { Ok("users#create") })
|
||||
.get(|_: Request<Body>| async { "users#index" })
|
||||
.post(|_: Request<Body>| async { "users#create" })
|
||||
.at("/users/:id")
|
||||
.get(|_: Request<Body>| async { Ok("users#show") })
|
||||
.get(|_: Request<Body>| async { "users#show" })
|
||||
.at("/users/:id/action")
|
||||
.get(|_: Request<Body>| async { Ok("users#action") })
|
||||
.get(|_: Request<Body>| async { "users#action" })
|
||||
.into_service();
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
@ -215,18 +216,14 @@ async fn extracting_url_params() {
|
|||
.at("/users/:id")
|
||||
.get(
|
||||
|_: Request<Body>, params: extract::UrlParams<(i32,)>| async move {
|
||||
let (id,) = params.into_inner();
|
||||
let id = params.into_inner();
|
||||
assert_eq!(id, 42);
|
||||
|
||||
Ok(response::Empty)
|
||||
},
|
||||
)
|
||||
.post(
|
||||
|_: Request<Body>, params_map: extract::UrlParamsMap| async move {
|
||||
assert_eq!(params_map.get("id").unwrap(), "1337");
|
||||
assert_eq!(params_map.get_typed::<i32>("id").unwrap(), 1337);
|
||||
|
||||
Ok(response::Empty)
|
||||
},
|
||||
)
|
||||
.into_service();
|
||||
|
@ -254,9 +251,9 @@ async fn extracting_url_params() {
|
|||
async fn boxing() {
|
||||
let app = app()
|
||||
.at("/")
|
||||
.get(|_: Request<Body>| async { Ok("hi from GET") })
|
||||
.get(|_: Request<Body>| async { "hi from GET" })
|
||||
.boxed()
|
||||
.post(|_: Request<Body>| async { Ok("hi from POST") })
|
||||
.post(|_: Request<Body>| async { "hi from POST" })
|
||||
.into_service();
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
|
Loading…
Reference in a new issue