1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-02-19 11:39:53 +01:00

Refactor TypedHeader extractor ()

I should use `HeaderMapExt::typed_try_get` rather than implementing it
manually.
This commit is contained in:
David Pedersen 2021-08-16 09:05:10 +02:00 committed by GitHub
parent 48afd30491
commit be7e9e9bc6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 107 additions and 78 deletions

View file

@ -341,40 +341,6 @@ where
}
}
/// Rejection used for [`TypedHeader`](super::TypedHeader).
#[cfg(feature = "headers")]
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
#[derive(Debug)]
pub struct TypedHeaderRejection {
pub(super) name: &'static http::header::HeaderName,
pub(super) err: headers::Error,
}
#[cfg(feature = "headers")]
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
impl IntoResponse for TypedHeaderRejection {
type Body = Full<Bytes>;
type BodyError = Infallible;
fn into_response(self) -> http::Response<Self::Body> {
let mut res = self.to_string().into_response();
*res.status_mut() = http::StatusCode::BAD_REQUEST;
res
}
}
#[cfg(feature = "headers")]
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
impl std::fmt::Display for TypedHeaderRejection {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{} ({})", self.err, self.name)
}
}
#[cfg(feature = "headers")]
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
impl std::error::Error for TypedHeaderRejection {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(&self.err)
}
}
pub use super::typed_header::TypedHeaderRejection;

View file

@ -1,7 +1,10 @@
use super::{rejection::TypedHeaderRejection, FromRequest, RequestParts};
use super::{FromRequest, RequestParts};
use crate::response::IntoResponse;
use async_trait::async_trait;
use headers::HeaderMap;
use std::ops::Deref;
use bytes::Bytes;
use headers::HeaderMapExt;
use http_body::Full;
use std::{convert::Infallible, ops::Deref};
/// Extractor that extracts a typed header value from [`headers`].
///
@ -36,19 +39,26 @@ where
type Rejection = TypedHeaderRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let empty_headers = HeaderMap::new();
let header_values = if let Some(headers) = req.headers() {
headers.get_all(T::name())
let headers = if let Some(headers) = req.headers() {
headers
} else {
empty_headers.get_all(T::name())
return Err(TypedHeaderRejection {
name: T::name(),
reason: Reason::Missing,
});
};
T::decode(&mut header_values.iter())
.map(Self)
.map_err(|err| TypedHeaderRejection {
err,
match headers.typed_try_get::<T>() {
Ok(Some(value)) => Ok(Self(value)),
Ok(None) => Err(TypedHeaderRejection {
name: T::name(),
})
reason: Reason::Missing,
}),
Err(err) => Err(TypedHeaderRejection {
name: T::name(),
reason: Reason::Error(err),
}),
}
}
}
@ -59,3 +69,85 @@ impl<T> Deref for TypedHeader<T> {
&self.0
}
}
/// Rejection used for [`TypedHeader`](super::TypedHeader).
#[cfg(feature = "headers")]
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
#[derive(Debug)]
pub struct TypedHeaderRejection {
name: &'static http::header::HeaderName,
reason: Reason,
}
#[derive(Debug)]
enum Reason {
Missing,
Error(headers::Error),
}
impl IntoResponse for TypedHeaderRejection {
type Body = Full<Bytes>;
type BodyError = Infallible;
fn into_response(self) -> http::Response<Self::Body> {
let mut res = self.to_string().into_response();
*res.status_mut() = http::StatusCode::BAD_REQUEST;
res
}
}
impl std::fmt::Display for TypedHeaderRejection {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.reason {
Reason::Missing => {
write!(f, "Header of type `{}` was missing", self.name)
}
Reason::Error(err) => {
write!(f, "{} ({})", err, self.name)
}
}
}
}
impl std::error::Error for TypedHeaderRejection {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match &self.reason {
Reason::Error(err) => Some(err),
Reason::Missing => None,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{handler::get, response::IntoResponse, route, tests::*};
#[tokio::test]
async fn typed_header() {
async fn handle(
TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
) -> impl IntoResponse {
user_agent.to_string()
}
let app = route("/", get(handle));
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}", addr))
.header("user-agent", "foobar")
.send()
.await
.unwrap();
let body = res.text().await.unwrap();
assert_eq!(body, "foobar");
let res = client.get(format!("http://{}", addr)).send().await.unwrap();
let body = res.text().await.unwrap();
assert_eq!(body, "Header of type `user-agent` was missing");
}
}

View file

@ -89,7 +89,7 @@
use crate::{
body::BoxBody,
response::IntoResponse,
routing::{future::RouteFuture, EmptyRouter, MethodFilter},
routing::{EmptyRouter, MethodFilter},
};
use bytes::Bytes;
use http::{Request, Response};

View file

@ -45,7 +45,7 @@ mod for_services {
async fn get_handles_head() {
let app = route(
"/",
get(service_fn(|req: Request<Body>| async move {
get(service_fn(|_req: Request<Body>| async move {
let res = Response::builder()
.header("x-some-header", "foobar".parse::<HeaderValue>().unwrap())
.body(Body::from("you shouldn't see this"))

View file

@ -420,35 +420,6 @@ async fn middleware_on_single_route() {
assert_eq!(body, "Hello, World!");
}
#[tokio::test]
#[cfg(feature = "header")]
async fn typed_header() {
use crate::{extract::TypedHeader, response::IntoResponse};
async fn handle(TypedHeader(user_agent): TypedHeader<headers::UserAgent>) -> impl IntoResponse {
user_agent.to_string()
}
let app = route("/", get(handle));
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}", addr))
.header("user-agent", "foobar")
.send()
.await
.unwrap();
let body = res.text().await.unwrap();
assert_eq!(body, "foobar");
let res = client.get(format!("http://{}", addr)).send().await.unwrap();
let body = res.text().await.unwrap();
assert_eq!(body, "invalid HTTP header (user-agent)");
}
#[tokio::test]
async fn service_in_bottom() {
async fn handler(_req: Request<hyper::Body>) -> Result<Response<hyper::Body>, hyper::Error> {