1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-04-26 13:56:22 +02:00

add OverrideAllStatusCodes

This commit is contained in:
David Pedersen 2023-12-30 17:44:12 +01:00
parent a2b6590d62
commit 4b5b1e74e4
3 changed files with 108 additions and 8 deletions
axum-core/src/response
axum/src/response

View file

@ -1,4 +1,6 @@
use super::{IntoResponseFailed, IntoResponseParts, Response, ResponseParts};
use super::{
IntoResponseFailed, IntoResponseParts, OverrideAllStatusCodes, Response, ResponseParts,
};
use crate::{body::Body, BoxError};
use bytes::{buf::Chain, Buf, Bytes, BytesMut};
use http::{
@ -448,6 +450,26 @@ macro_rules! impl_into_response {
}
}
#[allow(non_snake_case)]
impl<R, $($ty,)*> IntoResponse for (OverrideAllStatusCodes, $($ty),*, R)
where
$( $ty: IntoResponseParts, )*
R: IntoResponse,
{
fn into_response(self) -> Response {
let (status, $($ty),*, res) = self;
let res = res.into_response();
let parts = ResponseParts { res };
let parts = match ($($ty,)*).into_response_parts(parts) {
Ok(parts) => parts,
Err(err) => return err.into_response(),
};
(status, parts.res).into_response()
}
}
#[allow(non_snake_case)]
impl<R, $($ty,)*> IntoResponse for (http::response::Parts, $($ty),*, R)
where

View file

@ -6,6 +6,8 @@
use std::convert::Infallible;
use http::StatusCode;
use crate::body::Body;
mod append_headers;
@ -156,3 +158,29 @@ impl IntoResponseParts for IntoResponseFailed {
/// ```
#[allow(dead_code)]
fn into_response_failed_doesnt_impl_into_response() {}
/// Override all status codes regardless if [`IntoResponseFailed`] is used or not.
///
/// See the docs for [`IntoResponseFailed`] for more details.
#[derive(Debug, Copy, Clone, Default)]
pub struct OverrideAllStatusCodes(pub StatusCode);
impl IntoResponse for OverrideAllStatusCodes {
fn into_response(self) -> Response {
let mut res = ().into_response();
*res.status_mut() = self.0;
res
}
}
impl<R> IntoResponse for (OverrideAllStatusCodes, R)
where
R: IntoResponse,
{
fn into_response(self) -> Response {
let (OverrideAllStatusCodes(status), res) = self;
let mut res = res.into_response();
*res.status_mut() = status;
res
}
}

View file

@ -66,6 +66,7 @@ mod tests {
use crate::test_helpers::*;
use crate::Json;
use crate::{routing::get, Router};
use axum_core::response::OverrideAllStatusCodes;
use axum_core::response::{
IntoResponse, IntoResponseFailed, IntoResponseParts, Response, ResponseParts,
};
@ -446,19 +447,68 @@ mod tests {
);
}
#[test]
fn force_overriding_status_code() {
assert_eq!(
OverrideAllStatusCodes(StatusCode::IM_A_TEAPOT)
.into_response()
.status(),
StatusCode::IM_A_TEAPOT
);
assert_eq!(
(OverrideAllStatusCodes(StatusCode::IM_A_TEAPOT),)
.into_response()
.status(),
StatusCode::IM_A_TEAPOT
);
assert_eq!(
(OverrideAllStatusCodes(StatusCode::IM_A_TEAPOT), ())
.into_response()
.status(),
StatusCode::IM_A_TEAPOT
);
assert_eq!(
(
OverrideAllStatusCodes(StatusCode::IM_A_TEAPOT),
IntoResponseFailed,
StatusCode::INTERNAL_SERVER_ERROR,
)
.into_response()
.status(),
StatusCode::IM_A_TEAPOT
);
}
#[crate::test]
async fn status_code_tuple_doesnt_override_error_json() {
let app = Router::new().route(
"/",
get(|| async {
let not_json_compatible = HashMap::from([(Vec::from([1, 2, 3]), 123)]);
(StatusCode::IM_A_TEAPOT, Json(not_json_compatible))
}),
);
let app = Router::new()
.route(
"/",
get(|| async {
let not_json_compatible = HashMap::from([(Vec::from([1, 2, 3]), 123)]);
(StatusCode::IM_A_TEAPOT, Json(not_json_compatible))
}),
)
.route(
"/two",
get(|| async {
let not_json_compatible = HashMap::from([(Vec::from([1, 2, 3]), 123)]);
(
OverrideAllStatusCodes(StatusCode::IM_A_TEAPOT),
Json(not_json_compatible),
)
}),
);
let client = TestClient::new(app);
let res = client.get("/").send().await;
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
let res = client.get("/two").send().await;
assert_eq!(res.status(), StatusCode::IM_A_TEAPOT);
}
}