1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-04-26 13:56:22 +02:00
This commit is contained in:
David Pedersen 2023-12-30 16:37:38 +01:00
parent 5d388c8d91
commit a2b6590d62
5 changed files with 310 additions and 32 deletions
axum-core/src/response
axum/src

View file

@ -1,4 +1,4 @@
use super::{IntoResponseParts, Response, ResponseParts};
use super::{IntoResponseFailed, IntoResponseParts, Response, ResponseParts};
use crate::{body::Body, BoxError};
use bytes::{buf::Chain, Buf, Bytes, BytesMut};
use http::{
@ -328,7 +328,9 @@ where
{
fn into_response(self) -> Response {
let mut res = self.1.into_response();
*res.status_mut() = self.0;
if res.extensions().get::<IntoResponseFailed>().is_none() {
*res.status_mut() = self.0;
}
res
}
}
@ -404,18 +406,16 @@ macro_rules! impl_into_response {
let ($($ty),*, res) = self;
let res = res.into_response();
let parts = ResponseParts { res };
$(
let parts = match $ty.into_response_parts(parts) {
if res.extensions().get::<IntoResponseFailed>().is_none() {
let parts = ResponseParts { res };
let parts = match ($($ty,)*).into_response_parts(parts) {
Ok(parts) => parts,
Err(err) => {
return err.into_response();
}
Err(err) => return err.into_response(),
};
)*
parts.res
parts.res
} else {
res
}
}
}
@ -429,18 +429,22 @@ macro_rules! impl_into_response {
let (status, $($ty),*, res) = self;
let res = res.into_response();
let parts = ResponseParts { res };
$(
let parts = match $ty.into_response_parts(parts) {
if res.extensions().get::<IntoResponseFailed>().is_none() {
let parts = ResponseParts { res };
let mut parts = match ($($ty,)*).into_response_parts(parts) {
Ok(parts) => parts,
Err(err) => {
return err.into_response();
}
Err(err) => return err.into_response(),
};
)*
(status, parts.res).into_response()
// Don't call `(status, parts.res).into_response()` since that checks for
// `IntoResponseFailed` and skips setting the status. We've already done that
// check here so overriding the status is required if returning
// `(IntoResponseFailed, StatusCode::INTERNAL_SERVER_ERROR)`
*parts.res.status_mut() = status;
parts.res
} else {
res
}
}
}
@ -454,17 +458,22 @@ macro_rules! impl_into_response {
let (outer_parts, $($ty),*, res) = self;
let res = res.into_response();
let parts = ResponseParts { res };
$(
let parts = match $ty.into_response_parts(parts) {
if res.extensions().get::<IntoResponseFailed>().is_none() {
let parts = ResponseParts { res };
let mut parts = match ($($ty,)*).into_response_parts(parts) {
Ok(parts) => parts,
Err(err) => {
return err.into_response();
}
Err(err) => return err.into_response(),
};
)*
(outer_parts, parts.res).into_response()
// Don't call `(outer_parts, parts.res).into_response()` for the same reason we
// don't call `(status, parts.res).into_response()` in the above impl.
*parts.res.status_mut() = outer_parts.status;
parts.res.headers_mut().extend(outer_parts.headers);
parts.res.extensions_mut().extend(outer_parts.extensions);
parts.res
} else {
res
}
}
}

View file

@ -237,7 +237,9 @@ macro_rules! impl_into_response_parts {
let res = match $ty.into_response_parts(res) {
Ok(res) => res,
Err(err) => {
return Err(err.into_response());
let mut err_res = err.into_response();
err_res.extensions_mut().insert(super::IntoResponseFailed);
return Err(err_res);
}
};
)*

View file

@ -4,6 +4,8 @@
//!
//! [`axum::response`]: https://docs.rs/axum/0.7/axum/response/index.html
use std::convert::Infallible;
use crate::body::Body;
mod append_headers;
@ -127,3 +129,30 @@ where
Self(value.into_response())
}
}
/// ```
/// todo!();
/// ```
#[derive(Copy, Clone, Debug)]
pub struct IntoResponseFailed;
impl IntoResponseParts for IntoResponseFailed {
type Error = Infallible;
fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
res.extensions_mut().insert(self);
Ok(res)
}
}
/// Not sure it makes sense to return `IntoResponseFailed` as the whole response. You should
/// probably at least combine it with a status code.
///
/// ```compile_fail
/// fn foo()
/// where
/// axum_core::response::IntoResponseFailed: axum_core::response::IntoResponse,
/// {}
/// ```
#[allow(dead_code)]
fn into_response_failed_doesnt_impl_into_response() {}

View file

@ -1,7 +1,7 @@
use crate::extract::Request;
use crate::extract::{rejection::*, FromRequest};
use async_trait::async_trait;
use axum_core::response::{IntoResponse, Response};
use axum_core::response::{IntoResponse, IntoResponseFailed, Response};
use bytes::{BufMut, Bytes, BytesMut};
use http::{
header::{self, HeaderMap, HeaderValue},
@ -202,6 +202,7 @@ where
header::CONTENT_TYPE,
HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()),
)],
IntoResponseFailed,
err.to_string(),
)
.into_response(),

View file

@ -63,10 +63,15 @@ impl<T> From<T> for Html<T> {
#[cfg(test)]
mod tests {
use crate::extract::Extension;
use crate::test_helpers::*;
use crate::Json;
use crate::{routing::get, Router};
use axum_core::response::IntoResponse;
use axum_core::response::{
IntoResponse, IntoResponseFailed, IntoResponseParts, Response, ResponseParts,
};
use http::HeaderMap;
use http::{StatusCode, Uri};
use std::collections::HashMap;
// just needs to compile
#[allow(dead_code)]
@ -224,4 +229,236 @@ mod tests {
.route("/", get(header_array_extension_body))
.route("/", get(header_array_extension_mixed_body));
}
#[test]
fn status_code_tuple_doesnt_override_error() {
// sanity check where there is just one status code
assert_eq!(
StatusCode::INTERNAL_SERVER_ERROR.into_response().status(),
StatusCode::INTERNAL_SERVER_ERROR
);
assert_eq!(
(StatusCode::INTERNAL_SERVER_ERROR,)
.into_response()
.status(),
StatusCode::INTERNAL_SERVER_ERROR
);
// non-5xx status should be changed
assert_eq!(
(StatusCode::SEE_OTHER, StatusCode::NO_CONTENT)
.into_response()
.status(),
StatusCode::SEE_OTHER
);
let res = (
StatusCode::SEE_OTHER,
[("location", "foo")],
StatusCode::NO_CONTENT,
)
.into_response();
assert_eq!(res.status(), StatusCode::SEE_OTHER);
assert_eq!(res.headers()["location"], "foo");
// 5xx status codes are also changed
assert_eq!(
(StatusCode::SEE_OTHER, StatusCode::INTERNAL_SERVER_ERROR)
.into_response()
.status(),
StatusCode::SEE_OTHER
);
let res = (
StatusCode::SEE_OTHER,
[("location", "foo")],
StatusCode::INTERNAL_SERVER_ERROR,
)
.into_response();
assert_eq!(res.status(), StatusCode::SEE_OTHER);
assert_eq!(res.headers()["location"], "foo");
// the status is not changed if `IntoResponseFailed` is used
assert_eq!(
(
StatusCode::SEE_OTHER,
(IntoResponseFailed, StatusCode::INTERNAL_SERVER_ERROR)
)
.into_response()
.status(),
StatusCode::INTERNAL_SERVER_ERROR
);
let res = (
StatusCode::SEE_OTHER,
[("location", "foo")],
(IntoResponseFailed, StatusCode::INTERNAL_SERVER_ERROR),
)
.into_response();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert!(res.headers().get("location").is_none());
// response parts from the inner response do run
let res = (
// with status override
StatusCode::SEE_OTHER,
[("location", "foo")],
(
[("x-bar", "bar")],
IntoResponseFailed,
[("x-foo", "foo")],
StatusCode::INTERNAL_SERVER_ERROR,
),
)
.into_response();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert!(res.headers().get("location").is_none());
assert_eq!(res.headers()["x-foo"], "foo");
assert_eq!(res.headers()["x-bar"], "bar");
let res = (
// without status override
[("location", "foo")],
(
[("x-bar", "bar")],
IntoResponseFailed,
[("x-foo", "foo")],
StatusCode::INTERNAL_SERVER_ERROR,
),
)
.into_response();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert!(res.headers().get("location").is_none());
assert_eq!(res.headers()["x-foo"], "foo");
assert_eq!(res.headers()["x-bar"], "bar");
// (Parts, ...)
let res = (
Response::new(()).into_parts().0,
[("location", "foo")],
(
[("x-bar", "bar")],
IntoResponseFailed,
[("x-foo", "foo")],
StatusCode::INTERNAL_SERVER_ERROR,
),
)
.into_response();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert!(res.headers().get("location").is_none());
assert_eq!(res.headers()["x-foo"], "foo");
assert_eq!(res.headers()["x-bar"], "bar");
// (Response<()>, ...)
let res = (
Response::new(()),
[("location", "foo")],
(
[("x-bar", "bar")],
IntoResponseFailed,
[("x-foo", "foo")],
StatusCode::INTERNAL_SERVER_ERROR,
),
)
.into_response();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert!(res.headers().get("location").is_none());
assert_eq!(res.headers()["x-foo"], "foo");
assert_eq!(res.headers()["x-bar"], "bar");
}
#[test]
fn into_response_parts_failing_sets_extension() {
struct Fail;
impl IntoResponseParts for Fail {
type Error = ();
fn into_response_parts(
self,
_res: ResponseParts,
) -> Result<ResponseParts, Self::Error> {
Err(())
}
}
impl IntoResponse for Fail {
fn into_response(self) -> Response {
(self, ()).into_response()
}
}
assert!(Fail
.into_response()
.extensions()
.get::<IntoResponseFailed>()
.is_some());
assert!((StatusCode::INTERNAL_SERVER_ERROR, Fail, ())
.into_response()
.extensions()
.get::<IntoResponseFailed>()
.is_some());
assert!((Response::new(()).into_parts().0, Fail, ())
.into_response()
.extensions()
.get::<IntoResponseFailed>()
.is_some());
assert!((Response::new(()), Fail, ())
.into_response()
.extensions()
.get::<IntoResponseFailed>()
.is_some());
}
#[test]
fn doenst_override_status_code_when_using_into_response_failed_at_same_level() {
assert_eq!(
(StatusCode::INTERNAL_SERVER_ERROR, IntoResponseFailed, ())
.into_response()
.status(),
StatusCode::INTERNAL_SERVER_ERROR,
);
#[derive(Clone)]
struct Thing;
let res = (
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header("x-foo", "foo")
.extension(Thing)
.body(())
.unwrap()
.into_parts()
.0,
IntoResponseFailed,
(),
)
.into_response();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR,);
assert_eq!(res.headers()["x-foo"], "foo");
assert!(res.extensions().get::<Thing>().is_some());
// just a sanity check
assert_eq!(
(IntoResponseFailed, ()).into_response().status(),
StatusCode::OK,
);
}
#[crate::test]
async fn status_code_tuple_doesnt_override_error_json() {
let app = Router::new().route(
"/",
get(|| async {
let not_json_compatible = HashMap::from([(Vec::from([1, 2, 3]), 123)]);
(StatusCode::IM_A_TEAPOT, Json(not_json_compatible))
}),
);
let client = TestClient::new(app);
let res = client.get("/").send().await;
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
}