mirror of
https://github.com/tokio-rs/axum.git
synced 2024-12-28 07:20:12 +01:00
Generic response types
This commit is contained in:
parent
bf3d4f4a40
commit
433b3183c4
1 changed files with 177 additions and 92 deletions
269
src/lib.rs
269
src/lib.rs
|
@ -4,15 +4,10 @@
|
|||
|
||||
Improvements to make:
|
||||
|
||||
Somehow return generic "into response" kinda types without having to manually
|
||||
create hyper::Body for everything
|
||||
Break stuff up into modules
|
||||
|
||||
Support extracting headers, perhaps via `headers::Header`?
|
||||
|
||||
Body to bytes extractor
|
||||
|
||||
Implement `FromRequest` for more functions, with macro
|
||||
|
||||
Tests
|
||||
|
||||
*/
|
||||
|
@ -20,7 +15,7 @@ Tests
|
|||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use futures_util::{future, ready};
|
||||
use http::{Method, Request, Response, StatusCode};
|
||||
use http::{header, HeaderValue, Method, Request, Response, StatusCode};
|
||||
use http_body::Body as _;
|
||||
use pin_project::pin_project;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
|
@ -69,9 +64,9 @@ pub struct RouteAt<R> {
|
|||
}
|
||||
|
||||
impl<R> RouteAt<R> {
|
||||
pub fn get<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
|
||||
pub fn get<F, B, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, B, T>, R>>
|
||||
where
|
||||
F: Handler<T>,
|
||||
F: Handler<B, T>,
|
||||
{
|
||||
self.add_route(handler_fn, Method::GET)
|
||||
}
|
||||
|
@ -84,9 +79,9 @@ impl<R> RouteAt<R> {
|
|||
self.add_route_service(service, Method::GET)
|
||||
}
|
||||
|
||||
pub fn post<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
|
||||
pub fn post<F, B, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, B, T>, R>>
|
||||
where
|
||||
F: Handler<T>,
|
||||
F: Handler<B, T>,
|
||||
{
|
||||
self.add_route(handler_fn, Method::POST)
|
||||
}
|
||||
|
@ -99,9 +94,13 @@ impl<R> RouteAt<R> {
|
|||
self.add_route_service(service, Method::POST)
|
||||
}
|
||||
|
||||
fn add_route<H, T>(self, handler: H, method: Method) -> RouteBuilder<Route<HandlerSvc<H, T>, R>>
|
||||
fn add_route<H, B, T>(
|
||||
self,
|
||||
handler: H,
|
||||
method: Method,
|
||||
) -> RouteBuilder<Route<HandlerSvc<H, B, T>, R>>
|
||||
where
|
||||
H: Handler<T>,
|
||||
H: Handler<B, T>,
|
||||
{
|
||||
self.add_route_service(HandlerSvc::new(handler), method)
|
||||
}
|
||||
|
@ -149,9 +148,9 @@ impl<R> RouteBuilder<R> {
|
|||
self.app.at(route_spec)
|
||||
}
|
||||
|
||||
pub fn get<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
|
||||
pub fn get<F, B, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, B, T>, R>>
|
||||
where
|
||||
F: Handler<T>,
|
||||
F: Handler<B, T>,
|
||||
{
|
||||
self.app.at_bytes(self.route_spec).get(handler_fn)
|
||||
}
|
||||
|
@ -164,9 +163,9 @@ impl<R> RouteBuilder<R> {
|
|||
self.app.at_bytes(self.route_spec).get_service(service)
|
||||
}
|
||||
|
||||
pub fn post<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
|
||||
pub fn post<F, B, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, B, T>, R>>
|
||||
where
|
||||
F: Handler<T>,
|
||||
F: Handler<B, T>,
|
||||
{
|
||||
self.app.at_bytes(self.route_spec).post(handler_fn)
|
||||
}
|
||||
|
@ -193,6 +192,9 @@ pub enum Error {
|
|||
#[error("failed to deserialize the request body")]
|
||||
DeserializeRequestBody(#[source] serde_json::Error),
|
||||
|
||||
#[error("failed to serialize the response body")]
|
||||
SerializeResponseBody(#[source] serde_json::Error),
|
||||
|
||||
#[error("failed to consume the body")]
|
||||
ConsumeRequestBody(#[source] hyper::Error),
|
||||
|
||||
|
@ -225,8 +227,8 @@ mod sealed {
|
|||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Handler<In>: Sized {
|
||||
type Response: IntoResponse;
|
||||
pub trait Handler<B, In>: Sized {
|
||||
type Response: IntoResponse<B>;
|
||||
|
||||
// This seals the trait. We cannot use the regular "sealed super trait" approach
|
||||
// due to coherence.
|
||||
|
@ -237,37 +239,94 @@ pub trait Handler<In>: Sized {
|
|||
|
||||
fn layer<L>(self, layer: L) -> Layered<L::Service, In>
|
||||
where
|
||||
L: Layer<HandlerSvc<Self, In>>,
|
||||
L: Layer<HandlerSvc<Self, B, In>>,
|
||||
{
|
||||
Layered::new(layer.layer(HandlerSvc::new(self)))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait IntoResponse {
|
||||
fn into_response(self) -> Response<Body>;
|
||||
pub trait IntoResponse<B> {
|
||||
fn into_response(self) -> Result<Response<B>, Error>;
|
||||
}
|
||||
|
||||
impl<B> IntoResponse for Response<B>
|
||||
where
|
||||
B: Into<Body>,
|
||||
{
|
||||
fn into_response(self) -> Response<Body> {
|
||||
self.map(Into::into)
|
||||
impl<B> IntoResponse<B> for Response<B> {
|
||||
fn into_response(self) -> Result<Response<B>, Error> {
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for String {
|
||||
fn into_response(self) -> Response<Body> {
|
||||
Response::new(Body::from(self))
|
||||
impl IntoResponse<Body> for &'static str {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for String {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for Bytes {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for &'static [u8] {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for Vec<u8> {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for std::borrow::Cow<'static, str> {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for std::borrow::Cow<'static, [u8]> {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(david): rename this to Json when its in another module
|
||||
pub struct JsonBody<T>(T);
|
||||
|
||||
impl<T> IntoResponse<Body> for JsonBody<T>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
let bytes = serde_json::to_vec(&self.0).map_err(Error::SerializeResponseBody)?;
|
||||
let len = bytes.len();
|
||||
let mut res = Response::new(Body::from(bytes));
|
||||
|
||||
res.headers_mut().insert(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
res.headers_mut()
|
||||
.insert(header::CONTENT_LENGTH, HeaderValue::from(len));
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<F, Fut, Res> Handler<()> for F
|
||||
impl<F, Fut, B, Res> Handler<B, ()> for F
|
||||
where
|
||||
F: Fn(Request<Body>) -> Fut + Send + Sync,
|
||||
Fut: Future<Output = Result<Res, Error>> + Send,
|
||||
Res: IntoResponse,
|
||||
Res: IntoResponse<B>,
|
||||
{
|
||||
type Response = Res;
|
||||
|
||||
|
@ -282,11 +341,11 @@ macro_rules! impl_handler {
|
|||
( $head:ident $(,)? ) => {
|
||||
#[async_trait]
|
||||
#[allow(non_snake_case)]
|
||||
impl<F, Fut, Res, $head> Handler<($head,)> for F
|
||||
impl<F, Fut, B, Res, $head> Handler<B, ($head,)> for F
|
||||
where
|
||||
F: Fn(Request<Body>, $head) -> Fut + Send + Sync,
|
||||
Fut: Future<Output = Result<Res, Error>> + Send,
|
||||
Res: IntoResponse,
|
||||
Res: IntoResponse<B>,
|
||||
$head: FromRequest + Send,
|
||||
{
|
||||
type Response = Res;
|
||||
|
@ -304,11 +363,11 @@ macro_rules! impl_handler {
|
|||
( $head:ident, $($tail:ident),* $(,)? ) => {
|
||||
#[async_trait]
|
||||
#[allow(non_snake_case)]
|
||||
impl<F, Fut, Res, $head, $($tail,)*> Handler<($head, $($tail,)*)> for F
|
||||
impl<F, Fut, B, Res, $head, $($tail,)*> Handler<B, ($head, $($tail,)*)> for F
|
||||
where
|
||||
F: Fn(Request<Body>, $head, $($tail,)*) -> Fut + Send + Sync,
|
||||
Fut: Future<Output = Result<Res, Error>> + Send,
|
||||
Res: IntoResponse,
|
||||
Res: IntoResponse<B>,
|
||||
$head: FromRequest + Send,
|
||||
$( $tail: FromRequest + Send, )*
|
||||
{
|
||||
|
@ -347,10 +406,9 @@ where
|
|||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S, T> Handler<T> for Layered<S, T>
|
||||
impl<S, B, T> Handler<B, T> for Layered<S, T>
|
||||
where
|
||||
S: Service<Request<Body>> + Send,
|
||||
S::Response: IntoResponse,
|
||||
S: Service<Request<Body>, Response = Response<B>> + Send,
|
||||
S::Error: Into<BoxError>,
|
||||
S::Future: Send,
|
||||
{
|
||||
|
@ -375,12 +433,12 @@ impl<S, T> Layered<S, T> {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct HandlerSvc<H, T> {
|
||||
pub struct HandlerSvc<H, B, T> {
|
||||
handler: H,
|
||||
_input: PhantomData<fn() -> T>,
|
||||
_input: PhantomData<fn() -> (B, T)>,
|
||||
}
|
||||
|
||||
impl<H, T> HandlerSvc<H, T> {
|
||||
impl<H, B, T> HandlerSvc<H, B, T> {
|
||||
fn new(handler: H) -> Self {
|
||||
Self {
|
||||
handler,
|
||||
|
@ -389,7 +447,7 @@ impl<H, T> HandlerSvc<H, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<H, T> Clone for HandlerSvc<H, T>
|
||||
impl<H, B, T> Clone for HandlerSvc<H, B, T>
|
||||
where
|
||||
H: Clone,
|
||||
{
|
||||
|
@ -401,25 +459,27 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<H, T> Service<Request<Body>> for HandlerSvc<H, T>
|
||||
impl<H, B, T> Service<Request<Body>> for HandlerSvc<H, B, T>
|
||||
where
|
||||
H: Handler<T> + Clone + Send + 'static,
|
||||
H: Handler<B, T> + Clone + Send + 'static,
|
||||
H::Response: 'static,
|
||||
{
|
||||
type Response = Response<Body>;
|
||||
type Response = Response<B>;
|
||||
type Error = Error;
|
||||
type Future = future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
// HandlerSvc can only be constructed from async functions which are always ready
|
||||
// HandlerSvc can only be constructed from async functions which are always ready, or from
|
||||
// `Layered` which bufferes in `<Layered as Handler>::call` and is therefore also always
|
||||
// ready.
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
||||
let handler = self.handler.clone();
|
||||
Box::pin(async move {
|
||||
let res = Handler::call(handler, req).await?;
|
||||
Ok(res.into_response())
|
||||
let res = Handler::call(handler, req).await?.into_response()?;
|
||||
Ok(res)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -543,6 +603,31 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
// TODO(david): rename this to Bytes when its in another module
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BytesBody(Bytes);
|
||||
|
||||
impl BytesBody {
|
||||
pub fn into_inner(self) -> Bytes {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRequest for BytesBody {
|
||||
type Future = future::BoxFuture<'static, Result<Self, Error>>;
|
||||
|
||||
fn from_request(req: &mut Request<Body>) -> Self::Future {
|
||||
let body = std::mem::take(req.body_mut());
|
||||
|
||||
Box::pin(async move {
|
||||
let bytes = hyper::body::to_bytes(body)
|
||||
.await
|
||||
.map_err(Error::ConsumeRequestBody)?;
|
||||
Ok(BytesBody(bytes))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct EmptyRouter(());
|
||||
|
||||
|
@ -771,7 +856,9 @@ where
|
|||
| Error::QueryStringMissing
|
||||
| Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST),
|
||||
|
||||
Error::MissingExtension { .. } => make_response(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
Error::MissingExtension { .. } | Error::SerializeResponseBody(_) => {
|
||||
make_response(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
Error::Service(err) => match err.downcast::<Error>() {
|
||||
Ok(err) => Err(*err),
|
||||
|
@ -820,46 +907,44 @@ mod tests {
|
|||
Ok(Response::new(Body::empty()))
|
||||
}
|
||||
|
||||
async fn users_index(
|
||||
_: Request<Body>,
|
||||
pagination: Query<Pagination>,
|
||||
) -> Result<String, Error> {
|
||||
let pagination = pagination.into_inner();
|
||||
assert_eq!(pagination.page, 1);
|
||||
assert_eq!(pagination.per_page, 30);
|
||||
Ok::<_, Error>("users#index".to_string())
|
||||
}
|
||||
|
||||
let app = app()
|
||||
// routes with functions
|
||||
.at("/")
|
||||
.get(root)
|
||||
// routes with closures
|
||||
.at("/users")
|
||||
.get(users_index)
|
||||
.post(
|
||||
|_: Request<Body>,
|
||||
payload: Json<UsersCreate>,
|
||||
_state: Extension<Arc<State>>| async {
|
||||
let payload = payload.into_inner();
|
||||
assert_eq!(payload.username, "bob");
|
||||
Ok::<_, Error>(Response::new(Body::from("users#create")))
|
||||
},
|
||||
)
|
||||
// routes with a service
|
||||
.at("/service")
|
||||
.get_service(service_fn(root))
|
||||
// routes with layers applied
|
||||
.at("/large-static-file")
|
||||
.get(
|
||||
large_static_file.layer(
|
||||
ServiceBuilder::new()
|
||||
.layer(TimeoutLayer::new(Duration::from_secs(30)))
|
||||
.layer(CompressionLayer::new())
|
||||
.into_inner(),
|
||||
),
|
||||
)
|
||||
.into_service();
|
||||
let app =
|
||||
app()
|
||||
// routes with functions
|
||||
.at("/")
|
||||
.get(root)
|
||||
// routes with closures
|
||||
.at("/users")
|
||||
.get(|_: Request<Body>, pagination: Query<Pagination>| async {
|
||||
let pagination = pagination.into_inner();
|
||||
assert_eq!(pagination.page, 1);
|
||||
assert_eq!(pagination.per_page, 30);
|
||||
Ok::<_, Error>("users#index".to_string())
|
||||
})
|
||||
.post(
|
||||
|_: Request<Body>,
|
||||
payload: Json<UsersCreate>,
|
||||
_state: Extension<Arc<State>>| async {
|
||||
let payload = payload.into_inner();
|
||||
assert_eq!(payload.username, "bob");
|
||||
Ok::<_, Error>(JsonBody(
|
||||
serde_json::json!({ "username": payload.username }),
|
||||
))
|
||||
},
|
||||
)
|
||||
// routes with a service
|
||||
.at("/service")
|
||||
.get_service(service_fn(root))
|
||||
// routes with layers applied
|
||||
.at("/large-static-file")
|
||||
.get(
|
||||
large_static_file.layer(
|
||||
ServiceBuilder::new()
|
||||
.layer(TimeoutLayer::new(Duration::from_secs(30)))
|
||||
.layer(CompressionLayer::new())
|
||||
.into_inner(),
|
||||
),
|
||||
)
|
||||
.into_service();
|
||||
|
||||
// state shared by all routes, could hold db connection etc
|
||||
struct State {}
|
||||
|
@ -934,7 +1019,7 @@ mod tests {
|
|||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(body_to_string(res).await, "users#create");
|
||||
assert_eq!(body_to_string(res).await, r#"{"username":"bob"}"#);
|
||||
}
|
||||
|
||||
async fn body_to_string<B>(res: Response<B>) -> String
|
||||
|
|
Loading…
Reference in a new issue