Correctly handle HEAD requests (#129)

This commit is contained in:
David Pedersen 2021-08-15 20:27:13 +02:00 committed by GitHub
parent 9cd543401f
commit 995ffc1aa2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 184 additions and 22 deletions

View file

@ -21,6 +21,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
required for returning responses with bodies other than `hyper::Body` from
handlers. See the docs for advice on how to implement `IntoResponse` ([#86](https://github.com/tokio-rs/axum/pull/86))
- Replace `body::BoxStdError` with `Error`, which supports downcasting ([#150](https://github.com/tokio-rs/axum/pull/150))
- `get` routes will now also be called for `HEAD` requests but will always have
the response body removed ([#129](https://github.com/tokio-rs/axum/pull/129))
- Change WebSocket API to use an extractor ([#121](https://github.com/tokio-rs/axum/pull/121))
- Make WebSocket `Message` an enum ([#116](https://github.com/tokio-rs/axum/pull/116))
- `WebSocket` now uses `Error` as its error type ([#150](https://github.com/tokio-rs/axum/pull/150))

View file

@ -45,7 +45,7 @@ mime = { optional = true, version = "0.3" }
[dev-dependencies]
askama = "0.10.5"
bb8 = "0.7.0"
bb8 = "0.7.1"
bb8-postgres = "0.7.0"
futures = "0.3"
hyper = { version = "0.14", features = ["full"] }

View file

@ -4,8 +4,13 @@
//! cargo run --example tokio_postgres
//! ```
use axum::{extract::Extension, prelude::*, AddExtensionLayer};
use bb8::Pool;
use axum::{
async_trait,
extract::{Extension, FromRequest, RequestParts},
prelude::*,
AddExtensionLayer,
};
use bb8::{Pool, PooledConnection};
use bb8_postgres::PostgresConnectionManager;
use http::StatusCode;
use std::net::SocketAddr;
@ -26,7 +31,11 @@ async fn main() {
let pool = Pool::builder().build(manager).await.unwrap();
// build our application with some routes
let app = route("/", get(handler)).layer(AddExtensionLayer::new(pool));
let app = route(
"/",
get(using_connection_pool_extractor).post(using_connection_extractor),
)
.layer(AddExtensionLayer::new(pool));
// run it with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
@ -39,14 +48,10 @@ async fn main() {
type ConnectionPool = Pool<PostgresConnectionManager<NoTls>>;
async fn handler(
// we can exact the connection pool with `Extension`
async fn using_connection_pool_extractor(
Extension(pool): Extension<ConnectionPool>,
) -> Result<String, (StatusCode, String)> {
// We cannot get a connection directly via an extractor because
// `bb8::PooledConnection` contains a reference to the pool and
// `extract::FromRequest` cannot return types that contain references.
//
// So therefore we have to get a connection from the pool manually.
let conn = pool.get().await.map_err(internal_error)?;
let row = conn
@ -58,6 +63,40 @@ async fn handler(
Ok(two.to_string())
}
// we can also write a custom extractor that grabs a connection from the pool
// which setup is appropriate depends on your application
struct DatabaseConnection(PooledConnection<'static, PostgresConnectionManager<NoTls>>);
#[async_trait]
impl<B> FromRequest<B> for DatabaseConnection
where
B: Send,
{
type Rejection = (StatusCode, String);
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let Extension(pool) = Extension::<ConnectionPool>::from_request(req)
.await
.map_err(internal_error)?;
let conn = pool.get_owned().await.map_err(internal_error)?;
Ok(Self(conn))
}
}
async fn using_connection_extractor(
DatabaseConnection(conn): DatabaseConnection,
) -> Result<String, (StatusCode, String)> {
let row = conn
.query_one("select 1 + 1", &[])
.await
.map_err(internal_error)?;
let two: i32 = row.try_get(0).map_err(internal_error)?;
Ok(two.to_string())
}
/// Utility function for mapping any error into a `500 Internal Server Error`
/// response.
fn internal_error<E>(err: E) -> (StatusCode, String)

View file

@ -1,7 +1,8 @@
//! Handler future types.
use crate::body::BoxBody;
use http::{Request, Response};
use crate::body::{box_body, BoxBody};
use http::{Method, Request, Response};
use http_body::Empty;
use pin_project_lite::pin_project;
use std::{
convert::Infallible,
@ -27,6 +28,7 @@ pin_project! {
{
#[pin]
pub(super) inner: crate::routing::future::RouteFuture<S, F, B>,
pub(super) req_method: Method,
}
}
@ -38,6 +40,13 @@ where
type Output = Result<Response<BoxBody>, S::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().inner.poll(cx)
let this = self.project();
let response = futures_util::ready!(this.inner.poll(cx))?;
if this.req_method == &Method::HEAD {
let response = response.map(|_| box_body(Empty::new()));
Poll::Ready(Ok(response))
} else {
Poll::Ready(Ok(response))
}
}
}

View file

@ -79,11 +79,15 @@ where
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// Note that `get` routes will also be called for `HEAD` requests but will have
/// the response body removed. Make sure to add explicit `HEAD` routes
/// afterwards.
pub fn get<H, B, T>(handler: H) -> OnMethod<IntoService<H, B, T>, EmptyRouter>
where
H: Handler<B, T>,
{
on(MethodFilter::GET, handler)
on(MethodFilter::GET | MethodFilter::HEAD, handler)
}
/// Route `HEAD` requests to the given handler.
@ -508,11 +512,15 @@ impl<S, F> OnMethod<S, F> {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// Note that `get` routes will also be called for `HEAD` requests but will have
/// the response body removed. Make sure to add explicit `HEAD` routes
/// afterwards.
pub fn get<H, B, T>(self, handler: H) -> OnMethod<IntoService<H, B, T>, Self>
where
H: Handler<B, T>,
{
self.on(MethodFilter::GET, handler)
self.on(MethodFilter::GET | MethodFilter::HEAD, handler)
}
/// Chain an additional handler that will only accept `HEAD` requests.
@ -624,6 +632,8 @@ where
}
fn call(&mut self, req: Request<B>) -> Self::Future {
let req_method = req.method().clone();
let f = if self.method.matches(req.method()) {
let fut = self.svc.clone().oneshot(req);
RouteFuture::a(fut)
@ -632,6 +642,9 @@ where
RouteFuture::b(fut)
};
future::OnMethodFuture { inner: f }
future::OnMethodFuture {
inner: f,
req_method,
}
}
}

View file

@ -6,7 +6,8 @@ use crate::{
};
use bytes::Bytes;
use futures_util::ready;
use http::{Request, Response};
use http::{Method, Request, Response};
use http_body::Empty;
use pin_project_lite::pin_project;
use std::{
future::Future,
@ -85,6 +86,7 @@ pin_project! {
{
#[pin]
pub(super) inner: crate::routing::future::RouteFuture<S, F, B>,
pub(super) req_method: Method,
}
}
@ -96,6 +98,13 @@ where
type Output = Result<Response<BoxBody>, S::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().inner.poll(cx)
let this = self.project();
let response = futures_util::ready!(this.inner.poll(cx))?;
if this.req_method == &Method::HEAD {
let response = response.map(|_| box_body(Empty::new()));
Poll::Ready(Ok(response))
} else {
Poll::Ready(Ok(response))
}
}
}

View file

@ -152,11 +152,15 @@ where
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// Note that `get` routes will also be called for `HEAD` requests but will have
/// the response body removed. Make sure to add explicit `HEAD` routes
/// afterwards.
pub fn get<S, B>(svc: S) -> OnMethod<BoxResponseBody<S, B>, EmptyRouter<S::Error>>
where
S: Service<Request<B>> + Clone,
{
on(MethodFilter::GET, svc)
on(MethodFilter::GET | MethodFilter::HEAD, svc)
}
/// Route `HEAD` requests to the given service.
@ -322,11 +326,15 @@ impl<S, F> OnMethod<S, F> {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// Note that `get` routes will also be called for `HEAD` requests but will have
/// the response body removed. Make sure to add explicit `HEAD` routes
/// afterwards.
pub fn get<T, B>(self, svc: T) -> OnMethod<BoxResponseBody<T, B>, Self>
where
T: Service<Request<B>> + Clone,
{
self.on(MethodFilter::GET, svc)
self.on(MethodFilter::GET | MethodFilter::HEAD, svc)
}
/// Chain an additional service that will only accept `HEAD` requests.
@ -465,6 +473,8 @@ where
}
fn call(&mut self, req: Request<B>) -> Self::Future {
let req_method = req.method().clone();
let f = if self.method.matches(req.method()) {
let fut = self.svc.clone().oneshot(req);
RouteFuture::a(fut)
@ -473,7 +483,10 @@ where
RouteFuture::b(fut)
};
future::OnMethodFuture { inner: f }
future::OnMethodFuture {
inner: f,
req_method,
}
}
}

73
src/tests/get_to_head.rs Normal file
View file

@ -0,0 +1,73 @@
use super::*;
use http::Method;
use tower::ServiceExt;
mod for_handlers {
use super::*;
#[tokio::test]
async fn get_handles_head() {
let app = route(
"/",
get(|| async {
let mut headers = HeaderMap::new();
headers.insert("x-some-header", "foobar".parse().unwrap());
(headers, "you shouldn't see this")
}),
);
// don't use reqwest because it always strips bodies from HEAD responses
let res = app
.oneshot(
Request::builder()
.uri("/")
.method(Method::HEAD)
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.headers()["x-some-header"], "foobar");
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
assert_eq!(body.len(), 0);
}
}
mod for_services {
use super::*;
use crate::service::get;
#[tokio::test]
async fn get_handles_head() {
let app = route(
"/",
get((|| async {
let mut headers = HeaderMap::new();
headers.insert("x-some-header", "foobar".parse().unwrap());
(headers, "you shouldn't see this")
})
.into_service()),
);
// don't use reqwest because it always strips bodies from HEAD responses
let res = app
.oneshot(
Request::builder()
.uri("/")
.method(Method::HEAD)
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.headers()["x-some-header"], "foobar");
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
assert_eq!(body.len(), 0);
}
}

View file

@ -5,7 +5,10 @@ use crate::{
};
use bytes::Bytes;
use futures_util::future::Ready;
use http::{header::AUTHORIZATION, Request, Response, StatusCode, Uri};
use http::{
header::{HeaderMap, AUTHORIZATION},
Request, Response, StatusCode, Uri,
};
use hyper::{Body, Server};
use serde::Deserialize;
use serde_json::json;
@ -18,6 +21,7 @@ use std::{
};
use tower::{make::Shared, service_fn, BoxError, Service};
mod get_to_head;
mod handle_error;
mod nest;
mod or;