mirror of
https://github.com/tokio-rs/axum.git
synced 2025-04-26 13:56:22 +02:00
wip
This commit is contained in:
parent
5d388c8d91
commit
a2b6590d62
5 changed files with 310 additions and 32 deletions
|
@ -1,4 +1,4 @@
|
|||
use super::{IntoResponseParts, Response, ResponseParts};
|
||||
use super::{IntoResponseFailed, IntoResponseParts, Response, ResponseParts};
|
||||
use crate::{body::Body, BoxError};
|
||||
use bytes::{buf::Chain, Buf, Bytes, BytesMut};
|
||||
use http::{
|
||||
|
@ -328,7 +328,9 @@ where
|
|||
{
|
||||
fn into_response(self) -> Response {
|
||||
let mut res = self.1.into_response();
|
||||
*res.status_mut() = self.0;
|
||||
if res.extensions().get::<IntoResponseFailed>().is_none() {
|
||||
*res.status_mut() = self.0;
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
@ -404,18 +406,16 @@ macro_rules! impl_into_response {
|
|||
let ($($ty),*, res) = self;
|
||||
|
||||
let res = res.into_response();
|
||||
let parts = ResponseParts { res };
|
||||
|
||||
$(
|
||||
let parts = match $ty.into_response_parts(parts) {
|
||||
if res.extensions().get::<IntoResponseFailed>().is_none() {
|
||||
let parts = ResponseParts { res };
|
||||
let parts = match ($($ty,)*).into_response_parts(parts) {
|
||||
Ok(parts) => parts,
|
||||
Err(err) => {
|
||||
return err.into_response();
|
||||
}
|
||||
Err(err) => return err.into_response(),
|
||||
};
|
||||
)*
|
||||
|
||||
parts.res
|
||||
parts.res
|
||||
} else {
|
||||
res
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -429,18 +429,22 @@ macro_rules! impl_into_response {
|
|||
let (status, $($ty),*, res) = self;
|
||||
|
||||
let res = res.into_response();
|
||||
let parts = ResponseParts { res };
|
||||
|
||||
$(
|
||||
let parts = match $ty.into_response_parts(parts) {
|
||||
if res.extensions().get::<IntoResponseFailed>().is_none() {
|
||||
let parts = ResponseParts { res };
|
||||
let mut parts = match ($($ty,)*).into_response_parts(parts) {
|
||||
Ok(parts) => parts,
|
||||
Err(err) => {
|
||||
return err.into_response();
|
||||
}
|
||||
Err(err) => return err.into_response(),
|
||||
};
|
||||
)*
|
||||
|
||||
(status, parts.res).into_response()
|
||||
// Don't call `(status, parts.res).into_response()` since that checks for
|
||||
// `IntoResponseFailed` and skips setting the status. We've already done that
|
||||
// check here so overriding the status is required if returning
|
||||
// `(IntoResponseFailed, StatusCode::INTERNAL_SERVER_ERROR)`
|
||||
*parts.res.status_mut() = status;
|
||||
parts.res
|
||||
} else {
|
||||
res
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -454,17 +458,22 @@ macro_rules! impl_into_response {
|
|||
let (outer_parts, $($ty),*, res) = self;
|
||||
|
||||
let res = res.into_response();
|
||||
let parts = ResponseParts { res };
|
||||
$(
|
||||
let parts = match $ty.into_response_parts(parts) {
|
||||
if res.extensions().get::<IntoResponseFailed>().is_none() {
|
||||
let parts = ResponseParts { res };
|
||||
let mut parts = match ($($ty,)*).into_response_parts(parts) {
|
||||
Ok(parts) => parts,
|
||||
Err(err) => {
|
||||
return err.into_response();
|
||||
}
|
||||
Err(err) => return err.into_response(),
|
||||
};
|
||||
)*
|
||||
|
||||
(outer_parts, parts.res).into_response()
|
||||
// Don't call `(outer_parts, parts.res).into_response()` for the same reason we
|
||||
// don't call `(status, parts.res).into_response()` in the above impl.
|
||||
*parts.res.status_mut() = outer_parts.status;
|
||||
parts.res.headers_mut().extend(outer_parts.headers);
|
||||
parts.res.extensions_mut().extend(outer_parts.extensions);
|
||||
parts.res
|
||||
} else {
|
||||
res
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -237,7 +237,9 @@ macro_rules! impl_into_response_parts {
|
|||
let res = match $ty.into_response_parts(res) {
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
return Err(err.into_response());
|
||||
let mut err_res = err.into_response();
|
||||
err_res.extensions_mut().insert(super::IntoResponseFailed);
|
||||
return Err(err_res);
|
||||
}
|
||||
};
|
||||
)*
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
//!
|
||||
//! [`axum::response`]: https://docs.rs/axum/0.7/axum/response/index.html
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use crate::body::Body;
|
||||
|
||||
mod append_headers;
|
||||
|
@ -127,3 +129,30 @@ where
|
|||
Self(value.into_response())
|
||||
}
|
||||
}
|
||||
|
||||
/// ```
|
||||
/// todo!();
|
||||
/// ```
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct IntoResponseFailed;
|
||||
|
||||
impl IntoResponseParts for IntoResponseFailed {
|
||||
type Error = Infallible;
|
||||
|
||||
fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
|
||||
res.extensions_mut().insert(self);
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
/// Not sure it makes sense to return `IntoResponseFailed` as the whole response. You should
|
||||
/// probably at least combine it with a status code.
|
||||
///
|
||||
/// ```compile_fail
|
||||
/// fn foo()
|
||||
/// where
|
||||
/// axum_core::response::IntoResponseFailed: axum_core::response::IntoResponse,
|
||||
/// {}
|
||||
/// ```
|
||||
#[allow(dead_code)]
|
||||
fn into_response_failed_doesnt_impl_into_response() {}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::extract::Request;
|
||||
use crate::extract::{rejection::*, FromRequest};
|
||||
use async_trait::async_trait;
|
||||
use axum_core::response::{IntoResponse, Response};
|
||||
use axum_core::response::{IntoResponse, IntoResponseFailed, Response};
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use http::{
|
||||
header::{self, HeaderMap, HeaderValue},
|
||||
|
@ -202,6 +202,7 @@ where
|
|||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()),
|
||||
)],
|
||||
IntoResponseFailed,
|
||||
err.to_string(),
|
||||
)
|
||||
.into_response(),
|
||||
|
|
|
@ -63,10 +63,15 @@ impl<T> From<T> for Html<T> {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::extract::Extension;
|
||||
use crate::test_helpers::*;
|
||||
use crate::Json;
|
||||
use crate::{routing::get, Router};
|
||||
use axum_core::response::IntoResponse;
|
||||
use axum_core::response::{
|
||||
IntoResponse, IntoResponseFailed, IntoResponseParts, Response, ResponseParts,
|
||||
};
|
||||
use http::HeaderMap;
|
||||
use http::{StatusCode, Uri};
|
||||
use std::collections::HashMap;
|
||||
|
||||
// just needs to compile
|
||||
#[allow(dead_code)]
|
||||
|
@ -224,4 +229,236 @@ mod tests {
|
|||
.route("/", get(header_array_extension_body))
|
||||
.route("/", get(header_array_extension_mixed_body));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn status_code_tuple_doesnt_override_error() {
|
||||
// sanity check where there is just one status code
|
||||
assert_eq!(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response().status(),
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
);
|
||||
assert_eq!(
|
||||
(StatusCode::INTERNAL_SERVER_ERROR,)
|
||||
.into_response()
|
||||
.status(),
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
);
|
||||
|
||||
// non-5xx status should be changed
|
||||
assert_eq!(
|
||||
(StatusCode::SEE_OTHER, StatusCode::NO_CONTENT)
|
||||
.into_response()
|
||||
.status(),
|
||||
StatusCode::SEE_OTHER
|
||||
);
|
||||
let res = (
|
||||
StatusCode::SEE_OTHER,
|
||||
[("location", "foo")],
|
||||
StatusCode::NO_CONTENT,
|
||||
)
|
||||
.into_response();
|
||||
assert_eq!(res.status(), StatusCode::SEE_OTHER);
|
||||
assert_eq!(res.headers()["location"], "foo");
|
||||
|
||||
// 5xx status codes are also changed
|
||||
assert_eq!(
|
||||
(StatusCode::SEE_OTHER, StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.into_response()
|
||||
.status(),
|
||||
StatusCode::SEE_OTHER
|
||||
);
|
||||
let res = (
|
||||
StatusCode::SEE_OTHER,
|
||||
[("location", "foo")],
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
.into_response();
|
||||
assert_eq!(res.status(), StatusCode::SEE_OTHER);
|
||||
assert_eq!(res.headers()["location"], "foo");
|
||||
|
||||
// the status is not changed if `IntoResponseFailed` is used
|
||||
assert_eq!(
|
||||
(
|
||||
StatusCode::SEE_OTHER,
|
||||
(IntoResponseFailed, StatusCode::INTERNAL_SERVER_ERROR)
|
||||
)
|
||||
.into_response()
|
||||
.status(),
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
);
|
||||
let res = (
|
||||
StatusCode::SEE_OTHER,
|
||||
[("location", "foo")],
|
||||
(IntoResponseFailed, StatusCode::INTERNAL_SERVER_ERROR),
|
||||
)
|
||||
.into_response();
|
||||
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
assert!(res.headers().get("location").is_none());
|
||||
|
||||
// response parts from the inner response do run
|
||||
let res = (
|
||||
// with status override
|
||||
StatusCode::SEE_OTHER,
|
||||
[("location", "foo")],
|
||||
(
|
||||
[("x-bar", "bar")],
|
||||
IntoResponseFailed,
|
||||
[("x-foo", "foo")],
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
),
|
||||
)
|
||||
.into_response();
|
||||
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
assert!(res.headers().get("location").is_none());
|
||||
assert_eq!(res.headers()["x-foo"], "foo");
|
||||
assert_eq!(res.headers()["x-bar"], "bar");
|
||||
|
||||
let res = (
|
||||
// without status override
|
||||
[("location", "foo")],
|
||||
(
|
||||
[("x-bar", "bar")],
|
||||
IntoResponseFailed,
|
||||
[("x-foo", "foo")],
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
),
|
||||
)
|
||||
.into_response();
|
||||
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
assert!(res.headers().get("location").is_none());
|
||||
assert_eq!(res.headers()["x-foo"], "foo");
|
||||
assert_eq!(res.headers()["x-bar"], "bar");
|
||||
|
||||
// (Parts, ...)
|
||||
let res = (
|
||||
Response::new(()).into_parts().0,
|
||||
[("location", "foo")],
|
||||
(
|
||||
[("x-bar", "bar")],
|
||||
IntoResponseFailed,
|
||||
[("x-foo", "foo")],
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
),
|
||||
)
|
||||
.into_response();
|
||||
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
assert!(res.headers().get("location").is_none());
|
||||
assert_eq!(res.headers()["x-foo"], "foo");
|
||||
assert_eq!(res.headers()["x-bar"], "bar");
|
||||
|
||||
// (Response<()>, ...)
|
||||
let res = (
|
||||
Response::new(()),
|
||||
[("location", "foo")],
|
||||
(
|
||||
[("x-bar", "bar")],
|
||||
IntoResponseFailed,
|
||||
[("x-foo", "foo")],
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
),
|
||||
)
|
||||
.into_response();
|
||||
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
assert!(res.headers().get("location").is_none());
|
||||
assert_eq!(res.headers()["x-foo"], "foo");
|
||||
assert_eq!(res.headers()["x-bar"], "bar");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn into_response_parts_failing_sets_extension() {
|
||||
struct Fail;
|
||||
|
||||
impl IntoResponseParts for Fail {
|
||||
type Error = ();
|
||||
|
||||
fn into_response_parts(
|
||||
self,
|
||||
_res: ResponseParts,
|
||||
) -> Result<ResponseParts, Self::Error> {
|
||||
Err(())
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for Fail {
|
||||
fn into_response(self) -> Response {
|
||||
(self, ()).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
assert!(Fail
|
||||
.into_response()
|
||||
.extensions()
|
||||
.get::<IntoResponseFailed>()
|
||||
.is_some());
|
||||
|
||||
assert!((StatusCode::INTERNAL_SERVER_ERROR, Fail, ())
|
||||
.into_response()
|
||||
.extensions()
|
||||
.get::<IntoResponseFailed>()
|
||||
.is_some());
|
||||
|
||||
assert!((Response::new(()).into_parts().0, Fail, ())
|
||||
.into_response()
|
||||
.extensions()
|
||||
.get::<IntoResponseFailed>()
|
||||
.is_some());
|
||||
|
||||
assert!((Response::new(()), Fail, ())
|
||||
.into_response()
|
||||
.extensions()
|
||||
.get::<IntoResponseFailed>()
|
||||
.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn doenst_override_status_code_when_using_into_response_failed_at_same_level() {
|
||||
assert_eq!(
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, IntoResponseFailed, ())
|
||||
.into_response()
|
||||
.status(),
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
);
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Thing;
|
||||
|
||||
let res = (
|
||||
Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.header("x-foo", "foo")
|
||||
.extension(Thing)
|
||||
.body(())
|
||||
.unwrap()
|
||||
.into_parts()
|
||||
.0,
|
||||
IntoResponseFailed,
|
||||
(),
|
||||
)
|
||||
.into_response();
|
||||
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR,);
|
||||
assert_eq!(res.headers()["x-foo"], "foo");
|
||||
assert!(res.extensions().get::<Thing>().is_some());
|
||||
|
||||
// just a sanity check
|
||||
assert_eq!(
|
||||
(IntoResponseFailed, ()).into_response().status(),
|
||||
StatusCode::OK,
|
||||
);
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn status_code_tuple_doesnt_override_error_json() {
|
||||
let app = Router::new().route(
|
||||
"/",
|
||||
get(|| async {
|
||||
let not_json_compatible = HashMap::from([(Vec::from([1, 2, 3]), 123)]);
|
||||
(StatusCode::IM_A_TEAPOT, Json(not_json_compatible))
|
||||
}),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/").send().await;
|
||||
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue