mirror of
https://github.com/tokio-rs/axum.git
synced 2025-04-26 13:56:22 +02:00
add OverrideAllStatusCodes
This commit is contained in:
parent
a2b6590d62
commit
4b5b1e74e4
3 changed files with 108 additions and 8 deletions
|
@ -1,4 +1,6 @@
|
|||
use super::{IntoResponseFailed, IntoResponseParts, Response, ResponseParts};
|
||||
use super::{
|
||||
IntoResponseFailed, IntoResponseParts, OverrideAllStatusCodes, Response, ResponseParts,
|
||||
};
|
||||
use crate::{body::Body, BoxError};
|
||||
use bytes::{buf::Chain, Buf, Bytes, BytesMut};
|
||||
use http::{
|
||||
|
@ -448,6 +450,26 @@ macro_rules! impl_into_response {
|
|||
}
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
impl<R, $($ty,)*> IntoResponse for (OverrideAllStatusCodes, $($ty),*, R)
|
||||
where
|
||||
$( $ty: IntoResponseParts, )*
|
||||
R: IntoResponse,
|
||||
{
|
||||
fn into_response(self) -> Response {
|
||||
let (status, $($ty),*, res) = self;
|
||||
|
||||
let res = res.into_response();
|
||||
let parts = ResponseParts { res };
|
||||
let parts = match ($($ty,)*).into_response_parts(parts) {
|
||||
Ok(parts) => parts,
|
||||
Err(err) => return err.into_response(),
|
||||
};
|
||||
|
||||
(status, parts.res).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
impl<R, $($ty,)*> IntoResponse for (http::response::Parts, $($ty),*, R)
|
||||
where
|
||||
|
|
|
@ -6,6 +6,8 @@
|
|||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use http::StatusCode;
|
||||
|
||||
use crate::body::Body;
|
||||
|
||||
mod append_headers;
|
||||
|
@ -156,3 +158,29 @@ impl IntoResponseParts for IntoResponseFailed {
|
|||
/// ```
|
||||
#[allow(dead_code)]
|
||||
fn into_response_failed_doesnt_impl_into_response() {}
|
||||
|
||||
/// Override all status codes regardless if [`IntoResponseFailed`] is used or not.
|
||||
///
|
||||
/// See the docs for [`IntoResponseFailed`] for more details.
|
||||
#[derive(Debug, Copy, Clone, Default)]
|
||||
pub struct OverrideAllStatusCodes(pub StatusCode);
|
||||
|
||||
impl IntoResponse for OverrideAllStatusCodes {
|
||||
fn into_response(self) -> Response {
|
||||
let mut res = ().into_response();
|
||||
*res.status_mut() = self.0;
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> IntoResponse for (OverrideAllStatusCodes, R)
|
||||
where
|
||||
R: IntoResponse,
|
||||
{
|
||||
fn into_response(self) -> Response {
|
||||
let (OverrideAllStatusCodes(status), res) = self;
|
||||
let mut res = res.into_response();
|
||||
*res.status_mut() = status;
|
||||
res
|
||||
}
|
||||
}
|
||||
|
|
|
@ -66,6 +66,7 @@ mod tests {
|
|||
use crate::test_helpers::*;
|
||||
use crate::Json;
|
||||
use crate::{routing::get, Router};
|
||||
use axum_core::response::OverrideAllStatusCodes;
|
||||
use axum_core::response::{
|
||||
IntoResponse, IntoResponseFailed, IntoResponseParts, Response, ResponseParts,
|
||||
};
|
||||
|
@ -446,19 +447,68 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn force_overriding_status_code() {
|
||||
assert_eq!(
|
||||
OverrideAllStatusCodes(StatusCode::IM_A_TEAPOT)
|
||||
.into_response()
|
||||
.status(),
|
||||
StatusCode::IM_A_TEAPOT
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
(OverrideAllStatusCodes(StatusCode::IM_A_TEAPOT),)
|
||||
.into_response()
|
||||
.status(),
|
||||
StatusCode::IM_A_TEAPOT
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
(OverrideAllStatusCodes(StatusCode::IM_A_TEAPOT), ())
|
||||
.into_response()
|
||||
.status(),
|
||||
StatusCode::IM_A_TEAPOT
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
(
|
||||
OverrideAllStatusCodes(StatusCode::IM_A_TEAPOT),
|
||||
IntoResponseFailed,
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
.into_response()
|
||||
.status(),
|
||||
StatusCode::IM_A_TEAPOT
|
||||
);
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn status_code_tuple_doesnt_override_error_json() {
|
||||
let app = Router::new().route(
|
||||
"/",
|
||||
get(|| async {
|
||||
let not_json_compatible = HashMap::from([(Vec::from([1, 2, 3]), 123)]);
|
||||
(StatusCode::IM_A_TEAPOT, Json(not_json_compatible))
|
||||
}),
|
||||
);
|
||||
let app = Router::new()
|
||||
.route(
|
||||
"/",
|
||||
get(|| async {
|
||||
let not_json_compatible = HashMap::from([(Vec::from([1, 2, 3]), 123)]);
|
||||
(StatusCode::IM_A_TEAPOT, Json(not_json_compatible))
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/two",
|
||||
get(|| async {
|
||||
let not_json_compatible = HashMap::from([(Vec::from([1, 2, 3]), 123)]);
|
||||
(
|
||||
OverrideAllStatusCodes(StatusCode::IM_A_TEAPOT),
|
||||
Json(not_json_compatible),
|
||||
)
|
||||
}),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/").send().await;
|
||||
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
let res = client.get("/two").send().await;
|
||||
assert_eq!(res.status(), StatusCode::IM_A_TEAPOT);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue