mirror of
https://github.com/tokio-rs/axum.git
synced 2025-02-19 11:39:53 +01:00
Refactor TypedHeader
extractor (#189)
I should use `HeaderMapExt::typed_try_get` rather than implementing it manually.
This commit is contained in:
parent
48afd30491
commit
be7e9e9bc6
5 changed files with 107 additions and 78 deletions
src
|
@ -341,40 +341,6 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Rejection used for [`TypedHeader`](super::TypedHeader).
|
||||
#[cfg(feature = "headers")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
|
||||
#[derive(Debug)]
|
||||
pub struct TypedHeaderRejection {
|
||||
pub(super) name: &'static http::header::HeaderName,
|
||||
pub(super) err: headers::Error,
|
||||
}
|
||||
|
||||
#[cfg(feature = "headers")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
|
||||
impl IntoResponse for TypedHeaderRejection {
|
||||
type Body = Full<Bytes>;
|
||||
type BodyError = Infallible;
|
||||
|
||||
fn into_response(self) -> http::Response<Self::Body> {
|
||||
let mut res = self.to_string().into_response();
|
||||
*res.status_mut() = http::StatusCode::BAD_REQUEST;
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "headers")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
|
||||
impl std::fmt::Display for TypedHeaderRejection {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{} ({})", self.err, self.name)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "headers")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
|
||||
impl std::error::Error for TypedHeaderRejection {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
Some(&self.err)
|
||||
}
|
||||
}
|
||||
pub use super::typed_header::TypedHeaderRejection;
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
use super::{rejection::TypedHeaderRejection, FromRequest, RequestParts};
|
||||
use super::{FromRequest, RequestParts};
|
||||
use crate::response::IntoResponse;
|
||||
use async_trait::async_trait;
|
||||
use headers::HeaderMap;
|
||||
use std::ops::Deref;
|
||||
use bytes::Bytes;
|
||||
use headers::HeaderMapExt;
|
||||
use http_body::Full;
|
||||
use std::{convert::Infallible, ops::Deref};
|
||||
|
||||
/// Extractor that extracts a typed header value from [`headers`].
|
||||
///
|
||||
|
@ -36,19 +39,26 @@ where
|
|||
type Rejection = TypedHeaderRejection;
|
||||
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
let empty_headers = HeaderMap::new();
|
||||
let header_values = if let Some(headers) = req.headers() {
|
||||
headers.get_all(T::name())
|
||||
let headers = if let Some(headers) = req.headers() {
|
||||
headers
|
||||
} else {
|
||||
empty_headers.get_all(T::name())
|
||||
return Err(TypedHeaderRejection {
|
||||
name: T::name(),
|
||||
reason: Reason::Missing,
|
||||
});
|
||||
};
|
||||
|
||||
T::decode(&mut header_values.iter())
|
||||
.map(Self)
|
||||
.map_err(|err| TypedHeaderRejection {
|
||||
err,
|
||||
match headers.typed_try_get::<T>() {
|
||||
Ok(Some(value)) => Ok(Self(value)),
|
||||
Ok(None) => Err(TypedHeaderRejection {
|
||||
name: T::name(),
|
||||
})
|
||||
reason: Reason::Missing,
|
||||
}),
|
||||
Err(err) => Err(TypedHeaderRejection {
|
||||
name: T::name(),
|
||||
reason: Reason::Error(err),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -59,3 +69,85 @@ impl<T> Deref for TypedHeader<T> {
|
|||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Rejection used for [`TypedHeader`](super::TypedHeader).
|
||||
#[cfg(feature = "headers")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
|
||||
#[derive(Debug)]
|
||||
pub struct TypedHeaderRejection {
|
||||
name: &'static http::header::HeaderName,
|
||||
reason: Reason,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Reason {
|
||||
Missing,
|
||||
Error(headers::Error),
|
||||
}
|
||||
|
||||
impl IntoResponse for TypedHeaderRejection {
|
||||
type Body = Full<Bytes>;
|
||||
type BodyError = Infallible;
|
||||
|
||||
fn into_response(self) -> http::Response<Self::Body> {
|
||||
let mut res = self.to_string().into_response();
|
||||
*res.status_mut() = http::StatusCode::BAD_REQUEST;
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TypedHeaderRejection {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match &self.reason {
|
||||
Reason::Missing => {
|
||||
write!(f, "Header of type `{}` was missing", self.name)
|
||||
}
|
||||
Reason::Error(err) => {
|
||||
write!(f, "{} ({})", err, self.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for TypedHeaderRejection {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match &self.reason {
|
||||
Reason::Error(err) => Some(err),
|
||||
Reason::Missing => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{handler::get, response::IntoResponse, route, tests::*};
|
||||
|
||||
#[tokio::test]
|
||||
async fn typed_header() {
|
||||
async fn handle(
|
||||
TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
|
||||
) -> impl IntoResponse {
|
||||
user_agent.to_string()
|
||||
}
|
||||
|
||||
let app = route("/", get(handle));
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}", addr))
|
||||
.header("user-agent", "foobar")
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
let body = res.text().await.unwrap();
|
||||
assert_eq!(body, "foobar");
|
||||
|
||||
let res = client.get(format!("http://{}", addr)).send().await.unwrap();
|
||||
let body = res.text().await.unwrap();
|
||||
assert_eq!(body, "Header of type `user-agent` was missing");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -89,7 +89,7 @@
|
|||
use crate::{
|
||||
body::BoxBody,
|
||||
response::IntoResponse,
|
||||
routing::{future::RouteFuture, EmptyRouter, MethodFilter},
|
||||
routing::{EmptyRouter, MethodFilter},
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use http::{Request, Response};
|
||||
|
|
|
@ -45,7 +45,7 @@ mod for_services {
|
|||
async fn get_handles_head() {
|
||||
let app = route(
|
||||
"/",
|
||||
get(service_fn(|req: Request<Body>| async move {
|
||||
get(service_fn(|_req: Request<Body>| async move {
|
||||
let res = Response::builder()
|
||||
.header("x-some-header", "foobar".parse::<HeaderValue>().unwrap())
|
||||
.body(Body::from("you shouldn't see this"))
|
||||
|
|
|
@ -420,35 +420,6 @@ async fn middleware_on_single_route() {
|
|||
assert_eq!(body, "Hello, World!");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[cfg(feature = "header")]
|
||||
async fn typed_header() {
|
||||
use crate::{extract::TypedHeader, response::IntoResponse};
|
||||
|
||||
async fn handle(TypedHeader(user_agent): TypedHeader<headers::UserAgent>) -> impl IntoResponse {
|
||||
user_agent.to_string()
|
||||
}
|
||||
|
||||
let app = route("/", get(handle));
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}", addr))
|
||||
.header("user-agent", "foobar")
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
let body = res.text().await.unwrap();
|
||||
assert_eq!(body, "foobar");
|
||||
|
||||
let res = client.get(format!("http://{}", addr)).send().await.unwrap();
|
||||
let body = res.text().await.unwrap();
|
||||
assert_eq!(body, "invalid HTTP header (user-agent)");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn service_in_bottom() {
|
||||
async fn handler(_req: Request<hyper::Body>) -> Result<Response<hyper::Body>, hyper::Error> {
|
||||
|
|
Loading…
Add table
Reference in a new issue