mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-01 08:56:15 +01:00
Quality of life improvements
This commit is contained in:
parent
90b9dffce7
commit
ea582ab8d9
6 changed files with 42 additions and 125 deletions
|
@ -17,6 +17,7 @@ serde_json = "1.0"
|
||||||
serde_urlencoded = "0.7"
|
serde_urlencoded = "0.7"
|
||||||
thiserror = "1.0"
|
thiserror = "1.0"
|
||||||
tower = { version = "0.4", features = ["util", "buffer"] }
|
tower = { version = "0.4", features = ["util", "buffer"] }
|
||||||
|
tower-http = { version = "0.1", features = ["add-extension"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
hyper = { version = "0.14", features = ["full"] }
|
hyper = { version = "0.14", features = ["full"] }
|
||||||
|
|
|
@ -51,10 +51,10 @@ async fn get(
|
||||||
params: extract::UrlParams<(String,)>,
|
params: extract::UrlParams<(String,)>,
|
||||||
state: extract::Extension<SharedState>,
|
state: extract::Extension<SharedState>,
|
||||||
) -> Result<Bytes, StatusCode> {
|
) -> Result<Bytes, StatusCode> {
|
||||||
let state = state.into_inner();
|
let state = state.0;
|
||||||
let db = &state.lock().unwrap().db;
|
let db = &state.lock().unwrap().db;
|
||||||
|
|
||||||
let key = params.into_inner();
|
let (key,) = params.0;
|
||||||
|
|
||||||
if let Some(value) = db.get(&key) {
|
if let Some(value) = db.get(&key) {
|
||||||
Ok(value.clone())
|
Ok(value.clone())
|
||||||
|
@ -69,11 +69,11 @@ async fn set(
|
||||||
value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb
|
value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb
|
||||||
state: extract::Extension<SharedState>,
|
state: extract::Extension<SharedState>,
|
||||||
) {
|
) {
|
||||||
let state = state.into_inner();
|
let state = state.0;
|
||||||
let db = &mut state.lock().unwrap().db;
|
let db = &mut state.lock().unwrap().db;
|
||||||
|
|
||||||
let key = params.into_inner();
|
let (key,) = params.0;
|
||||||
let value = value.into_inner();
|
let value = value.0;
|
||||||
|
|
||||||
db.insert(key, value);
|
db.insert(key, value);
|
||||||
}
|
}
|
||||||
|
|
128
src/extract.rs
128
src/extract.rs
|
@ -1,10 +1,7 @@
|
||||||
use crate::{
|
use crate::{body::Body, response::IntoResponse};
|
||||||
body::Body,
|
|
||||||
response::{BoxIntoResponse, IntoResponse},
|
|
||||||
};
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use http::{header, Request};
|
use http::{header, Response, Request};
|
||||||
use serde::de::DeserializeOwned;
|
use serde::de::DeserializeOwned;
|
||||||
use std::{collections::HashMap, convert::Infallible, str::FromStr};
|
use std::{collections::HashMap, convert::Infallible, str::FromStr};
|
||||||
|
|
||||||
|
@ -80,13 +77,7 @@ define_rejection! {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct Query<T>(T);
|
pub struct Query<T>(pub T);
|
||||||
|
|
||||||
impl<T> Query<T> {
|
|
||||||
pub fn into_inner(self) -> T {
|
|
||||||
self.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<T> FromRequest<Body> for Query<T>
|
impl<T> FromRequest<Body> for Query<T>
|
||||||
|
@ -103,13 +94,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct Json<T>(T);
|
pub struct Json<T>(pub T);
|
||||||
|
|
||||||
impl<T> Json<T> {
|
|
||||||
pub fn into_inner(self) -> T {
|
|
||||||
self.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
define_rejection! {
|
define_rejection! {
|
||||||
#[status = BAD_REQUEST]
|
#[status = BAD_REQUEST]
|
||||||
|
@ -128,24 +113,24 @@ impl<T> FromRequest<Body> for Json<T>
|
||||||
where
|
where
|
||||||
T: DeserializeOwned,
|
T: DeserializeOwned,
|
||||||
{
|
{
|
||||||
type Rejection = BoxIntoResponse<Body>;
|
type Rejection = Response<Body>;
|
||||||
|
|
||||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||||
if has_content_type(&req, "application/json") {
|
if has_content_type(&req, "application/json") {
|
||||||
let body = take_body(req).map_err(IntoResponse::boxed)?;
|
let body = take_body(req).map_err(IntoResponse::into_response)?;
|
||||||
|
|
||||||
let bytes = hyper::body::to_bytes(body)
|
let bytes = hyper::body::to_bytes(body)
|
||||||
.await
|
.await
|
||||||
.map_err(InvalidJsonBody::from_err)
|
.map_err(InvalidJsonBody::from_err)
|
||||||
.map_err(IntoResponse::boxed)?;
|
.map_err(IntoResponse::into_response)?;
|
||||||
|
|
||||||
let value = serde_json::from_slice(&bytes)
|
let value = serde_json::from_slice(&bytes)
|
||||||
.map_err(InvalidJsonBody::from_err)
|
.map_err(InvalidJsonBody::from_err)
|
||||||
.map_err(IntoResponse::boxed)?;
|
.map_err(IntoResponse::into_response)?;
|
||||||
|
|
||||||
Ok(Json(value))
|
Ok(Json(value))
|
||||||
} else {
|
} else {
|
||||||
Err(MissingJsonContentType(()).boxed())
|
Err(MissingJsonContentType(()).into_response())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -173,13 +158,7 @@ define_rejection! {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct Extension<T>(T);
|
pub struct Extension<T>(pub T);
|
||||||
|
|
||||||
impl<T> Extension<T> {
|
|
||||||
pub fn into_inner(self) -> T {
|
|
||||||
self.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<T> FromRequest<Body> for Extension<T>
|
impl<T> FromRequest<Body> for Extension<T>
|
||||||
|
@ -207,15 +186,15 @@ define_rejection! {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl FromRequest<Body> for Bytes {
|
impl FromRequest<Body> for Bytes {
|
||||||
type Rejection = BoxIntoResponse<Body>;
|
type Rejection = Response<Body>;
|
||||||
|
|
||||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||||
let body = take_body(req).map_err(IntoResponse::boxed)?;
|
let body = take_body(req).map_err(IntoResponse::into_response)?;
|
||||||
|
|
||||||
let bytes = hyper::body::to_bytes(body)
|
let bytes = hyper::body::to_bytes(body)
|
||||||
.await
|
.await
|
||||||
.map_err(FailedToBufferBody::from_err)
|
.map_err(FailedToBufferBody::from_err)
|
||||||
.map_err(IntoResponse::boxed)?;
|
.map_err(IntoResponse::into_response)?;
|
||||||
|
|
||||||
Ok(bytes)
|
Ok(bytes)
|
||||||
}
|
}
|
||||||
|
@ -229,20 +208,20 @@ define_rejection! {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl FromRequest<Body> for String {
|
impl FromRequest<Body> for String {
|
||||||
type Rejection = BoxIntoResponse<Body>;
|
type Rejection = Response<Body>;
|
||||||
|
|
||||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||||
let body = take_body(req).map_err(IntoResponse::boxed)?;
|
let body = take_body(req).map_err(IntoResponse::into_response)?;
|
||||||
|
|
||||||
let bytes = hyper::body::to_bytes(body)
|
let bytes = hyper::body::to_bytes(body)
|
||||||
.await
|
.await
|
||||||
.map_err(FailedToBufferBody::from_err)
|
.map_err(FailedToBufferBody::from_err)
|
||||||
.map_err(IntoResponse::boxed)?
|
.map_err(IntoResponse::into_response)?
|
||||||
.to_vec();
|
.to_vec();
|
||||||
|
|
||||||
let string = String::from_utf8(bytes)
|
let string = String::from_utf8(bytes)
|
||||||
.map_err(InvalidUtf8::from_err)
|
.map_err(InvalidUtf8::from_err)
|
||||||
.map_err(IntoResponse::boxed)?;
|
.map_err(IntoResponse::into_response)?;
|
||||||
|
|
||||||
Ok(string)
|
Ok(string)
|
||||||
}
|
}
|
||||||
|
@ -270,36 +249,30 @@ define_rejection! {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct BytesMaxLength<const N: u64>(Bytes);
|
pub struct BytesMaxLength<const N: u64>(pub Bytes);
|
||||||
|
|
||||||
impl<const N: u64> BytesMaxLength<N> {
|
|
||||||
pub fn into_inner(self) -> Bytes {
|
|
||||||
self.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<const N: u64> FromRequest<Body> for BytesMaxLength<N> {
|
impl<const N: u64> FromRequest<Body> for BytesMaxLength<N> {
|
||||||
type Rejection = BoxIntoResponse<Body>;
|
type Rejection = Response<Body>;
|
||||||
|
|
||||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||||
let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned();
|
let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned();
|
||||||
let body = take_body(req).map_err(|reject| reject.boxed())?;
|
let body = take_body(req).map_err(|reject| reject.into_response())?;
|
||||||
|
|
||||||
let content_length =
|
let content_length =
|
||||||
content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok());
|
content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok());
|
||||||
|
|
||||||
if let Some(length) = content_length {
|
if let Some(length) = content_length {
|
||||||
if length > N {
|
if length > N {
|
||||||
return Err(PayloadTooLarge(()).boxed());
|
return Err(PayloadTooLarge(()).into_response());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return Err(LengthRequired(()).boxed());
|
return Err(LengthRequired(()).into_response());
|
||||||
};
|
};
|
||||||
|
|
||||||
let bytes = hyper::body::to_bytes(body)
|
let bytes = hyper::body::to_bytes(body)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| FailedToBufferBody::from_err(e).boxed())?;
|
.map_err(|e| FailedToBufferBody::from_err(e).into_response())?;
|
||||||
|
|
||||||
Ok(BytesMaxLength(bytes))
|
Ok(BytesMaxLength(bytes))
|
||||||
}
|
}
|
||||||
|
@ -367,7 +340,7 @@ impl IntoResponse<Body> for InvalidUrlParam {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct UrlParams<T>(T);
|
pub struct UrlParams<T>(pub T);
|
||||||
|
|
||||||
macro_rules! impl_parse_url {
|
macro_rules! impl_parse_url {
|
||||||
() => {};
|
() => {};
|
||||||
|
@ -379,7 +352,7 @@ macro_rules! impl_parse_url {
|
||||||
$head: FromStr + Send,
|
$head: FromStr + Send,
|
||||||
$( $tail: FromStr + Send, )*
|
$( $tail: FromStr + Send, )*
|
||||||
{
|
{
|
||||||
type Rejection = BoxIntoResponse<Body>;
|
type Rejection = Response<Body>;
|
||||||
|
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||||
|
@ -389,27 +362,27 @@ macro_rules! impl_parse_url {
|
||||||
{
|
{
|
||||||
params.take().expect("params already taken").0
|
params.take().expect("params already taken").0
|
||||||
} else {
|
} else {
|
||||||
return Err(MissingRouteParams(()).boxed())
|
return Err(MissingRouteParams(()).into_response())
|
||||||
};
|
};
|
||||||
|
|
||||||
if let [(_, $head), $((_, $tail),)*] = &*params {
|
if let [(_, $head), $((_, $tail),)*] = &*params {
|
||||||
let $head = if let Ok(x) = $head.parse::<$head>() {
|
let $head = if let Ok(x) = $head.parse::<$head>() {
|
||||||
x
|
x
|
||||||
} else {
|
} else {
|
||||||
return Err(InvalidUrlParam::new::<$head>().boxed());
|
return Err(InvalidUrlParam::new::<$head>().into_response());
|
||||||
};
|
};
|
||||||
|
|
||||||
$(
|
$(
|
||||||
let $tail = if let Ok(x) = $tail.parse::<$tail>() {
|
let $tail = if let Ok(x) = $tail.parse::<$tail>() {
|
||||||
x
|
x
|
||||||
} else {
|
} else {
|
||||||
return Err(InvalidUrlParam::new::<$tail>().boxed());
|
return Err(InvalidUrlParam::new::<$tail>().into_response());
|
||||||
};
|
};
|
||||||
)*
|
)*
|
||||||
|
|
||||||
Ok(UrlParams(($head, $($tail,)*)))
|
Ok(UrlParams(($head, $($tail,)*)))
|
||||||
} else {
|
} else {
|
||||||
return Err(MissingRouteParams(()).boxed())
|
return Err(MissingRouteParams(()).into_response())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -420,49 +393,6 @@ macro_rules! impl_parse_url {
|
||||||
|
|
||||||
impl_parse_url!(T1, T2, T3, T4, T5, T6);
|
impl_parse_url!(T1, T2, T3, T4, T5, T6);
|
||||||
|
|
||||||
impl<T1> UrlParams<(T1,)> {
|
|
||||||
pub fn into_inner(self) -> T1 {
|
|
||||||
(self.0).0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T1, T2> UrlParams<(T1, T2)> {
|
|
||||||
pub fn into_inner(self) -> (T1, T2) {
|
|
||||||
((self.0).0, (self.0).1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T1, T2, T3> UrlParams<(T1, T2, T3)> {
|
|
||||||
pub fn into_inner(self) -> (T1, T2, T3) {
|
|
||||||
((self.0).0, (self.0).1, (self.0).2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T1, T2, T3, T4> UrlParams<(T1, T2, T3, T4)> {
|
|
||||||
pub fn into_inner(self) -> (T1, T2, T3, T4) {
|
|
||||||
((self.0).0, (self.0).1, (self.0).2, (self.0).3)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T1, T2, T3, T4, T5> UrlParams<(T1, T2, T3, T4, T5)> {
|
|
||||||
pub fn into_inner(self) -> (T1, T2, T3, T4, T5) {
|
|
||||||
((self.0).0, (self.0).1, (self.0).2, (self.0).3, (self.0).4)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T1, T2, T3, T4, T5, T6> UrlParams<(T1, T2, T3, T4, T5, T6)> {
|
|
||||||
pub fn into_inner(self) -> (T1, T2, T3, T4, T5, T6) {
|
|
||||||
(
|
|
||||||
(self.0).0,
|
|
||||||
(self.0).1,
|
|
||||||
(self.0).2,
|
|
||||||
(self.0).3,
|
|
||||||
(self.0).4,
|
|
||||||
(self.0).5,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
define_rejection! {
|
define_rejection! {
|
||||||
#[status = INTERNAL_SERVER_ERROR]
|
#[status = INTERNAL_SERVER_ERROR]
|
||||||
#[body = "Cannot have two request body extractors for a single handler"]
|
#[body = "Cannot have two request body extractors for a single handler"]
|
||||||
|
|
|
@ -22,6 +22,9 @@ pub mod handler;
|
||||||
pub mod response;
|
pub mod response;
|
||||||
pub mod routing;
|
pub mod routing;
|
||||||
|
|
||||||
|
pub use tower_http::add_extension::{AddExtension, AddExtensionLayer};
|
||||||
|
pub use async_trait::async_trait;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests;
|
mod tests;
|
||||||
|
|
||||||
|
|
|
@ -5,17 +5,9 @@ use serde::Serialize;
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
use tower::util::Either;
|
use tower::util::Either;
|
||||||
|
|
||||||
|
// TODO(david): can we change this to not be generic over the body and just use hyper::Body?
|
||||||
pub trait IntoResponse<B> {
|
pub trait IntoResponse<B> {
|
||||||
fn into_response(self) -> Response<B>;
|
fn into_response(self) -> Response<B>;
|
||||||
|
|
||||||
// TODO(david): remove this an return return `Response<B>` instead. That is what this method
|
|
||||||
// does anyway.
|
|
||||||
fn boxed(self) -> BoxIntoResponse<B>
|
|
||||||
where
|
|
||||||
Self: Sized + 'static,
|
|
||||||
{
|
|
||||||
BoxIntoResponse(self.into_response())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B> IntoResponse<B> for ()
|
impl<B> IntoResponse<B> for ()
|
||||||
|
@ -171,14 +163,6 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct BoxIntoResponse<B>(Response<B>);
|
|
||||||
|
|
||||||
impl<B> IntoResponse<B> for BoxIntoResponse<B> {
|
|
||||||
fn into_response(self) -> Response<B> {
|
|
||||||
self.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B> IntoResponse<B> for StatusCode
|
impl<B> IntoResponse<B> for StatusCode
|
||||||
where
|
where
|
||||||
B: Default,
|
B: Default,
|
||||||
|
|
|
@ -51,7 +51,7 @@ async fn deserialize_body() {
|
||||||
|
|
||||||
let app = app()
|
let app = app()
|
||||||
.at("/")
|
.at("/")
|
||||||
.post(|_: Request<Body>, input: extract::Json<Input>| async { input.into_inner().foo })
|
.post(|_: Request<Body>, input: extract::Json<Input>| async { input.0.foo })
|
||||||
.into_service();
|
.into_service();
|
||||||
|
|
||||||
let addr = run_in_background(app).await;
|
let addr = run_in_background(app).await;
|
||||||
|
@ -78,8 +78,7 @@ async fn consume_body_to_json_requires_json_content_type() {
|
||||||
let app = app()
|
let app = app()
|
||||||
.at("/")
|
.at("/")
|
||||||
.post(|_: Request<Body>, input: extract::Json<Input>| async {
|
.post(|_: Request<Body>, input: extract::Json<Input>| async {
|
||||||
let input = input.into_inner();
|
input.0.foo
|
||||||
input.foo
|
|
||||||
})
|
})
|
||||||
.into_service();
|
.into_service();
|
||||||
|
|
||||||
|
@ -216,7 +215,7 @@ async fn extracting_url_params() {
|
||||||
.at("/users/:id")
|
.at("/users/:id")
|
||||||
.get(
|
.get(
|
||||||
|_: Request<Body>, params: extract::UrlParams<(i32,)>| async move {
|
|_: Request<Body>, params: extract::UrlParams<(i32,)>| async move {
|
||||||
let id = params.into_inner();
|
let (id,) = params.0;
|
||||||
assert_eq!(id, 42);
|
assert_eq!(id, 42);
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue