mirror of
https://github.com/tokio-rs/axum.git
synced 2024-12-29 15:49:16 +01:00
Generic response types
This commit is contained in:
parent
bf3d4f4a40
commit
433b3183c4
1 changed files with 177 additions and 92 deletions
215
src/lib.rs
215
src/lib.rs
|
@ -4,15 +4,10 @@
|
||||||
|
|
||||||
Improvements to make:
|
Improvements to make:
|
||||||
|
|
||||||
Somehow return generic "into response" kinda types without having to manually
|
Break stuff up into modules
|
||||||
create hyper::Body for everything
|
|
||||||
|
|
||||||
Support extracting headers, perhaps via `headers::Header`?
|
Support extracting headers, perhaps via `headers::Header`?
|
||||||
|
|
||||||
Body to bytes extractor
|
|
||||||
|
|
||||||
Implement `FromRequest` for more functions, with macro
|
|
||||||
|
|
||||||
Tests
|
Tests
|
||||||
|
|
||||||
*/
|
*/
|
||||||
|
@ -20,7 +15,7 @@ Tests
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use futures_util::{future, ready};
|
use futures_util::{future, ready};
|
||||||
use http::{Method, Request, Response, StatusCode};
|
use http::{header, HeaderValue, Method, Request, Response, StatusCode};
|
||||||
use http_body::Body as _;
|
use http_body::Body as _;
|
||||||
use pin_project::pin_project;
|
use pin_project::pin_project;
|
||||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||||
|
@ -69,9 +64,9 @@ pub struct RouteAt<R> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R> RouteAt<R> {
|
impl<R> RouteAt<R> {
|
||||||
pub fn get<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
|
pub fn get<F, B, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, B, T>, R>>
|
||||||
where
|
where
|
||||||
F: Handler<T>,
|
F: Handler<B, T>,
|
||||||
{
|
{
|
||||||
self.add_route(handler_fn, Method::GET)
|
self.add_route(handler_fn, Method::GET)
|
||||||
}
|
}
|
||||||
|
@ -84,9 +79,9 @@ impl<R> RouteAt<R> {
|
||||||
self.add_route_service(service, Method::GET)
|
self.add_route_service(service, Method::GET)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn post<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
|
pub fn post<F, B, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, B, T>, R>>
|
||||||
where
|
where
|
||||||
F: Handler<T>,
|
F: Handler<B, T>,
|
||||||
{
|
{
|
||||||
self.add_route(handler_fn, Method::POST)
|
self.add_route(handler_fn, Method::POST)
|
||||||
}
|
}
|
||||||
|
@ -99,9 +94,13 @@ impl<R> RouteAt<R> {
|
||||||
self.add_route_service(service, Method::POST)
|
self.add_route_service(service, Method::POST)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_route<H, T>(self, handler: H, method: Method) -> RouteBuilder<Route<HandlerSvc<H, T>, R>>
|
fn add_route<H, B, T>(
|
||||||
|
self,
|
||||||
|
handler: H,
|
||||||
|
method: Method,
|
||||||
|
) -> RouteBuilder<Route<HandlerSvc<H, B, T>, R>>
|
||||||
where
|
where
|
||||||
H: Handler<T>,
|
H: Handler<B, T>,
|
||||||
{
|
{
|
||||||
self.add_route_service(HandlerSvc::new(handler), method)
|
self.add_route_service(HandlerSvc::new(handler), method)
|
||||||
}
|
}
|
||||||
|
@ -149,9 +148,9 @@ impl<R> RouteBuilder<R> {
|
||||||
self.app.at(route_spec)
|
self.app.at(route_spec)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
|
pub fn get<F, B, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, B, T>, R>>
|
||||||
where
|
where
|
||||||
F: Handler<T>,
|
F: Handler<B, T>,
|
||||||
{
|
{
|
||||||
self.app.at_bytes(self.route_spec).get(handler_fn)
|
self.app.at_bytes(self.route_spec).get(handler_fn)
|
||||||
}
|
}
|
||||||
|
@ -164,9 +163,9 @@ impl<R> RouteBuilder<R> {
|
||||||
self.app.at_bytes(self.route_spec).get_service(service)
|
self.app.at_bytes(self.route_spec).get_service(service)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn post<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
|
pub fn post<F, B, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, B, T>, R>>
|
||||||
where
|
where
|
||||||
F: Handler<T>,
|
F: Handler<B, T>,
|
||||||
{
|
{
|
||||||
self.app.at_bytes(self.route_spec).post(handler_fn)
|
self.app.at_bytes(self.route_spec).post(handler_fn)
|
||||||
}
|
}
|
||||||
|
@ -193,6 +192,9 @@ pub enum Error {
|
||||||
#[error("failed to deserialize the request body")]
|
#[error("failed to deserialize the request body")]
|
||||||
DeserializeRequestBody(#[source] serde_json::Error),
|
DeserializeRequestBody(#[source] serde_json::Error),
|
||||||
|
|
||||||
|
#[error("failed to serialize the response body")]
|
||||||
|
SerializeResponseBody(#[source] serde_json::Error),
|
||||||
|
|
||||||
#[error("failed to consume the body")]
|
#[error("failed to consume the body")]
|
||||||
ConsumeRequestBody(#[source] hyper::Error),
|
ConsumeRequestBody(#[source] hyper::Error),
|
||||||
|
|
||||||
|
@ -225,8 +227,8 @@ mod sealed {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Handler<In>: Sized {
|
pub trait Handler<B, In>: Sized {
|
||||||
type Response: IntoResponse;
|
type Response: IntoResponse<B>;
|
||||||
|
|
||||||
// This seals the trait. We cannot use the regular "sealed super trait" approach
|
// This seals the trait. We cannot use the regular "sealed super trait" approach
|
||||||
// due to coherence.
|
// due to coherence.
|
||||||
|
@ -237,37 +239,94 @@ pub trait Handler<In>: Sized {
|
||||||
|
|
||||||
fn layer<L>(self, layer: L) -> Layered<L::Service, In>
|
fn layer<L>(self, layer: L) -> Layered<L::Service, In>
|
||||||
where
|
where
|
||||||
L: Layer<HandlerSvc<Self, In>>,
|
L: Layer<HandlerSvc<Self, B, In>>,
|
||||||
{
|
{
|
||||||
Layered::new(layer.layer(HandlerSvc::new(self)))
|
Layered::new(layer.layer(HandlerSvc::new(self)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait IntoResponse {
|
pub trait IntoResponse<B> {
|
||||||
fn into_response(self) -> Response<Body>;
|
fn into_response(self) -> Result<Response<B>, Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B> IntoResponse for Response<B>
|
impl<B> IntoResponse<B> for Response<B> {
|
||||||
|
fn into_response(self) -> Result<Response<B>, Error> {
|
||||||
|
Ok(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse<Body> for &'static str {
|
||||||
|
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||||
|
Ok(Response::new(Body::from(self)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse<Body> for String {
|
||||||
|
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||||
|
Ok(Response::new(Body::from(self)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse<Body> for Bytes {
|
||||||
|
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||||
|
Ok(Response::new(Body::from(self)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse<Body> for &'static [u8] {
|
||||||
|
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||||
|
Ok(Response::new(Body::from(self)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse<Body> for Vec<u8> {
|
||||||
|
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||||
|
Ok(Response::new(Body::from(self)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse<Body> for std::borrow::Cow<'static, str> {
|
||||||
|
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||||
|
Ok(Response::new(Body::from(self)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse<Body> for std::borrow::Cow<'static, [u8]> {
|
||||||
|
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||||
|
Ok(Response::new(Body::from(self)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(david): rename this to Json when its in another module
|
||||||
|
pub struct JsonBody<T>(T);
|
||||||
|
|
||||||
|
impl<T> IntoResponse<Body> for JsonBody<T>
|
||||||
where
|
where
|
||||||
B: Into<Body>,
|
T: Serialize,
|
||||||
{
|
{
|
||||||
fn into_response(self) -> Response<Body> {
|
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||||
self.map(Into::into)
|
let bytes = serde_json::to_vec(&self.0).map_err(Error::SerializeResponseBody)?;
|
||||||
}
|
let len = bytes.len();
|
||||||
}
|
let mut res = Response::new(Body::from(bytes));
|
||||||
|
|
||||||
impl IntoResponse for String {
|
res.headers_mut().insert(
|
||||||
fn into_response(self) -> Response<Body> {
|
header::CONTENT_TYPE,
|
||||||
Response::new(Body::from(self))
|
HeaderValue::from_static("application/json"),
|
||||||
|
);
|
||||||
|
|
||||||
|
res.headers_mut()
|
||||||
|
.insert(header::CONTENT_LENGTH, HeaderValue::from(len));
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<F, Fut, Res> Handler<()> for F
|
impl<F, Fut, B, Res> Handler<B, ()> for F
|
||||||
where
|
where
|
||||||
F: Fn(Request<Body>) -> Fut + Send + Sync,
|
F: Fn(Request<Body>) -> Fut + Send + Sync,
|
||||||
Fut: Future<Output = Result<Res, Error>> + Send,
|
Fut: Future<Output = Result<Res, Error>> + Send,
|
||||||
Res: IntoResponse,
|
Res: IntoResponse<B>,
|
||||||
{
|
{
|
||||||
type Response = Res;
|
type Response = Res;
|
||||||
|
|
||||||
|
@ -282,11 +341,11 @@ macro_rules! impl_handler {
|
||||||
( $head:ident $(,)? ) => {
|
( $head:ident $(,)? ) => {
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
impl<F, Fut, Res, $head> Handler<($head,)> for F
|
impl<F, Fut, B, Res, $head> Handler<B, ($head,)> for F
|
||||||
where
|
where
|
||||||
F: Fn(Request<Body>, $head) -> Fut + Send + Sync,
|
F: Fn(Request<Body>, $head) -> Fut + Send + Sync,
|
||||||
Fut: Future<Output = Result<Res, Error>> + Send,
|
Fut: Future<Output = Result<Res, Error>> + Send,
|
||||||
Res: IntoResponse,
|
Res: IntoResponse<B>,
|
||||||
$head: FromRequest + Send,
|
$head: FromRequest + Send,
|
||||||
{
|
{
|
||||||
type Response = Res;
|
type Response = Res;
|
||||||
|
@ -304,11 +363,11 @@ macro_rules! impl_handler {
|
||||||
( $head:ident, $($tail:ident),* $(,)? ) => {
|
( $head:ident, $($tail:ident),* $(,)? ) => {
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
impl<F, Fut, Res, $head, $($tail,)*> Handler<($head, $($tail,)*)> for F
|
impl<F, Fut, B, Res, $head, $($tail,)*> Handler<B, ($head, $($tail,)*)> for F
|
||||||
where
|
where
|
||||||
F: Fn(Request<Body>, $head, $($tail,)*) -> Fut + Send + Sync,
|
F: Fn(Request<Body>, $head, $($tail,)*) -> Fut + Send + Sync,
|
||||||
Fut: Future<Output = Result<Res, Error>> + Send,
|
Fut: Future<Output = Result<Res, Error>> + Send,
|
||||||
Res: IntoResponse,
|
Res: IntoResponse<B>,
|
||||||
$head: FromRequest + Send,
|
$head: FromRequest + Send,
|
||||||
$( $tail: FromRequest + Send, )*
|
$( $tail: FromRequest + Send, )*
|
||||||
{
|
{
|
||||||
|
@ -347,10 +406,9 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<S, T> Handler<T> for Layered<S, T>
|
impl<S, B, T> Handler<B, T> for Layered<S, T>
|
||||||
where
|
where
|
||||||
S: Service<Request<Body>> + Send,
|
S: Service<Request<Body>, Response = Response<B>> + Send,
|
||||||
S::Response: IntoResponse,
|
|
||||||
S::Error: Into<BoxError>,
|
S::Error: Into<BoxError>,
|
||||||
S::Future: Send,
|
S::Future: Send,
|
||||||
{
|
{
|
||||||
|
@ -375,12 +433,12 @@ impl<S, T> Layered<S, T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct HandlerSvc<H, T> {
|
pub struct HandlerSvc<H, B, T> {
|
||||||
handler: H,
|
handler: H,
|
||||||
_input: PhantomData<fn() -> T>,
|
_input: PhantomData<fn() -> (B, T)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<H, T> HandlerSvc<H, T> {
|
impl<H, B, T> HandlerSvc<H, B, T> {
|
||||||
fn new(handler: H) -> Self {
|
fn new(handler: H) -> Self {
|
||||||
Self {
|
Self {
|
||||||
handler,
|
handler,
|
||||||
|
@ -389,7 +447,7 @@ impl<H, T> HandlerSvc<H, T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<H, T> Clone for HandlerSvc<H, T>
|
impl<H, B, T> Clone for HandlerSvc<H, B, T>
|
||||||
where
|
where
|
||||||
H: Clone,
|
H: Clone,
|
||||||
{
|
{
|
||||||
|
@ -401,25 +459,27 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<H, T> Service<Request<Body>> for HandlerSvc<H, T>
|
impl<H, B, T> Service<Request<Body>> for HandlerSvc<H, B, T>
|
||||||
where
|
where
|
||||||
H: Handler<T> + Clone + Send + 'static,
|
H: Handler<B, T> + Clone + Send + 'static,
|
||||||
H::Response: 'static,
|
H::Response: 'static,
|
||||||
{
|
{
|
||||||
type Response = Response<Body>;
|
type Response = Response<B>;
|
||||||
type Error = Error;
|
type Error = Error;
|
||||||
type Future = future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
|
type Future = future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
|
||||||
|
|
||||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
// HandlerSvc can only be constructed from async functions which are always ready
|
// HandlerSvc can only be constructed from async functions which are always ready, or from
|
||||||
|
// `Layered` which bufferes in `<Layered as Handler>::call` and is therefore also always
|
||||||
|
// ready.
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
||||||
let handler = self.handler.clone();
|
let handler = self.handler.clone();
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let res = Handler::call(handler, req).await?;
|
let res = Handler::call(handler, req).await?.into_response()?;
|
||||||
Ok(res.into_response())
|
Ok(res)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -543,6 +603,31 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(david): rename this to Bytes when its in another module
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct BytesBody(Bytes);
|
||||||
|
|
||||||
|
impl BytesBody {
|
||||||
|
pub fn into_inner(self) -> Bytes {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRequest for BytesBody {
|
||||||
|
type Future = future::BoxFuture<'static, Result<Self, Error>>;
|
||||||
|
|
||||||
|
fn from_request(req: &mut Request<Body>) -> Self::Future {
|
||||||
|
let body = std::mem::take(req.body_mut());
|
||||||
|
|
||||||
|
Box::pin(async move {
|
||||||
|
let bytes = hyper::body::to_bytes(body)
|
||||||
|
.await
|
||||||
|
.map_err(Error::ConsumeRequestBody)?;
|
||||||
|
Ok(BytesBody(bytes))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
#[derive(Clone, Copy)]
|
||||||
pub struct EmptyRouter(());
|
pub struct EmptyRouter(());
|
||||||
|
|
||||||
|
@ -771,7 +856,9 @@ where
|
||||||
| Error::QueryStringMissing
|
| Error::QueryStringMissing
|
||||||
| Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST),
|
| Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST),
|
||||||
|
|
||||||
Error::MissingExtension { .. } => make_response(StatusCode::INTERNAL_SERVER_ERROR),
|
Error::MissingExtension { .. } | Error::SerializeResponseBody(_) => {
|
||||||
|
make_response(StatusCode::INTERNAL_SERVER_ERROR)
|
||||||
|
}
|
||||||
|
|
||||||
Error::Service(err) => match err.downcast::<Error>() {
|
Error::Service(err) => match err.downcast::<Error>() {
|
||||||
Ok(err) => Err(*err),
|
Ok(err) => Err(*err),
|
||||||
|
@ -820,30 +907,28 @@ mod tests {
|
||||||
Ok(Response::new(Body::empty()))
|
Ok(Response::new(Body::empty()))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn users_index(
|
let app =
|
||||||
_: Request<Body>,
|
app()
|
||||||
pagination: Query<Pagination>,
|
|
||||||
) -> Result<String, Error> {
|
|
||||||
let pagination = pagination.into_inner();
|
|
||||||
assert_eq!(pagination.page, 1);
|
|
||||||
assert_eq!(pagination.per_page, 30);
|
|
||||||
Ok::<_, Error>("users#index".to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
let app = app()
|
|
||||||
// routes with functions
|
// routes with functions
|
||||||
.at("/")
|
.at("/")
|
||||||
.get(root)
|
.get(root)
|
||||||
// routes with closures
|
// routes with closures
|
||||||
.at("/users")
|
.at("/users")
|
||||||
.get(users_index)
|
.get(|_: Request<Body>, pagination: Query<Pagination>| async {
|
||||||
|
let pagination = pagination.into_inner();
|
||||||
|
assert_eq!(pagination.page, 1);
|
||||||
|
assert_eq!(pagination.per_page, 30);
|
||||||
|
Ok::<_, Error>("users#index".to_string())
|
||||||
|
})
|
||||||
.post(
|
.post(
|
||||||
|_: Request<Body>,
|
|_: Request<Body>,
|
||||||
payload: Json<UsersCreate>,
|
payload: Json<UsersCreate>,
|
||||||
_state: Extension<Arc<State>>| async {
|
_state: Extension<Arc<State>>| async {
|
||||||
let payload = payload.into_inner();
|
let payload = payload.into_inner();
|
||||||
assert_eq!(payload.username, "bob");
|
assert_eq!(payload.username, "bob");
|
||||||
Ok::<_, Error>(Response::new(Body::from("users#create")))
|
Ok::<_, Error>(JsonBody(
|
||||||
|
serde_json::json!({ "username": payload.username }),
|
||||||
|
))
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
// routes with a service
|
// routes with a service
|
||||||
|
@ -934,7 +1019,7 @@ mod tests {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(res.status(), StatusCode::OK);
|
assert_eq!(res.status(), StatusCode::OK);
|
||||||
assert_eq!(body_to_string(res).await, "users#create");
|
assert_eq!(body_to_string(res).await, r#"{"username":"bob"}"#);
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn body_to_string<B>(res: Response<B>) -> String
|
async fn body_to_string<B>(res: Response<B>) -> String
|
||||||
|
|
Loading…
Reference in a new issue