From 995ffc1aa2a46c93017a90b31158591e582b37d4 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sun, 15 Aug 2021 20:27:13 +0200 Subject: [PATCH] Correctly handle HEAD requests (#129) --- CHANGELOG.md | 2 ++ Cargo.toml | 2 +- examples/tokio_postgres.rs | 57 ++++++++++++++++++++++++----- src/handler/future.rs | 15 ++++++-- src/handler/mod.rs | 19 ++++++++-- src/service/future.rs | 13 +++++-- src/service/mod.rs | 19 ++++++++-- src/tests/get_to_head.rs | 73 ++++++++++++++++++++++++++++++++++++++ src/tests/mod.rs | 6 +++- 9 files changed, 184 insertions(+), 22 deletions(-) create mode 100644 src/tests/get_to_head.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index aa28ae70..d28c8f42 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 required for returning responses with bodies other than `hyper::Body` from handlers. See the docs for advice on how to implement `IntoResponse` ([#86](https://github.com/tokio-rs/axum/pull/86)) - Replace `body::BoxStdError` with `Error`, which supports downcasting ([#150](https://github.com/tokio-rs/axum/pull/150)) +- `get` routes will now also be called for `HEAD` requests but will always have + the response body removed ([#129](https://github.com/tokio-rs/axum/pull/129)) - Change WebSocket API to use an extractor ([#121](https://github.com/tokio-rs/axum/pull/121)) - Make WebSocket `Message` an enum ([#116](https://github.com/tokio-rs/axum/pull/116)) - `WebSocket` now uses `Error` as its error type ([#150](https://github.com/tokio-rs/axum/pull/150)) diff --git a/Cargo.toml b/Cargo.toml index e20cdc06..853a3d7b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ mime = { optional = true, version = "0.3" } [dev-dependencies] askama = "0.10.5" -bb8 = "0.7.0" +bb8 = "0.7.1" bb8-postgres = "0.7.0" futures = "0.3" hyper = { version = "0.14", features = ["full"] } diff --git a/examples/tokio_postgres.rs b/examples/tokio_postgres.rs index 8f915220..794a664b 100644 --- a/examples/tokio_postgres.rs +++ b/examples/tokio_postgres.rs @@ -4,8 +4,13 @@ //! cargo run --example tokio_postgres //! ``` -use axum::{extract::Extension, prelude::*, AddExtensionLayer}; -use bb8::Pool; +use axum::{ + async_trait, + extract::{Extension, FromRequest, RequestParts}, + prelude::*, + AddExtensionLayer, +}; +use bb8::{Pool, PooledConnection}; use bb8_postgres::PostgresConnectionManager; use http::StatusCode; use std::net::SocketAddr; @@ -26,7 +31,11 @@ async fn main() { let pool = Pool::builder().build(manager).await.unwrap(); // build our application with some routes - let app = route("/", get(handler)).layer(AddExtensionLayer::new(pool)); + let app = route( + "/", + get(using_connection_pool_extractor).post(using_connection_extractor), + ) + .layer(AddExtensionLayer::new(pool)); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); @@ -39,14 +48,10 @@ async fn main() { type ConnectionPool = Pool>; -async fn handler( +// we can exact the connection pool with `Extension` +async fn using_connection_pool_extractor( Extension(pool): Extension, ) -> Result { - // We cannot get a connection directly via an extractor because - // `bb8::PooledConnection` contains a reference to the pool and - // `extract::FromRequest` cannot return types that contain references. - // - // So therefore we have to get a connection from the pool manually. let conn = pool.get().await.map_err(internal_error)?; let row = conn @@ -58,6 +63,40 @@ async fn handler( Ok(two.to_string()) } +// we can also write a custom extractor that grabs a connection from the pool +// which setup is appropriate depends on your application +struct DatabaseConnection(PooledConnection<'static, PostgresConnectionManager>); + +#[async_trait] +impl FromRequest for DatabaseConnection +where + B: Send, +{ + type Rejection = (StatusCode, String); + + async fn from_request(req: &mut RequestParts) -> Result { + let Extension(pool) = Extension::::from_request(req) + .await + .map_err(internal_error)?; + + let conn = pool.get_owned().await.map_err(internal_error)?; + + Ok(Self(conn)) + } +} + +async fn using_connection_extractor( + DatabaseConnection(conn): DatabaseConnection, +) -> Result { + let row = conn + .query_one("select 1 + 1", &[]) + .await + .map_err(internal_error)?; + let two: i32 = row.try_get(0).map_err(internal_error)?; + + Ok(two.to_string()) +} + /// Utility function for mapping any error into a `500 Internal Server Error` /// response. fn internal_error(err: E) -> (StatusCode, String) diff --git a/src/handler/future.rs b/src/handler/future.rs index 8368b327..0810dc52 100644 --- a/src/handler/future.rs +++ b/src/handler/future.rs @@ -1,7 +1,8 @@ //! Handler future types. -use crate::body::BoxBody; -use http::{Request, Response}; +use crate::body::{box_body, BoxBody}; +use http::{Method, Request, Response}; +use http_body::Empty; use pin_project_lite::pin_project; use std::{ convert::Infallible, @@ -27,6 +28,7 @@ pin_project! { { #[pin] pub(super) inner: crate::routing::future::RouteFuture, + pub(super) req_method: Method, } } @@ -38,6 +40,13 @@ where type Output = Result, S::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project().inner.poll(cx) + let this = self.project(); + let response = futures_util::ready!(this.inner.poll(cx))?; + if this.req_method == &Method::HEAD { + let response = response.map(|_| box_body(Empty::new())); + Poll::Ready(Ok(response)) + } else { + Poll::Ready(Ok(response)) + } } } diff --git a/src/handler/mod.rs b/src/handler/mod.rs index e32e3f41..c94f69ea 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -79,11 +79,15 @@ where /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` +/// +/// Note that `get` routes will also be called for `HEAD` requests but will have +/// the response body removed. Make sure to add explicit `HEAD` routes +/// afterwards. pub fn get(handler: H) -> OnMethod, EmptyRouter> where H: Handler, { - on(MethodFilter::GET, handler) + on(MethodFilter::GET | MethodFilter::HEAD, handler) } /// Route `HEAD` requests to the given handler. @@ -508,11 +512,15 @@ impl OnMethod { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` + /// + /// Note that `get` routes will also be called for `HEAD` requests but will have + /// the response body removed. Make sure to add explicit `HEAD` routes + /// afterwards. pub fn get(self, handler: H) -> OnMethod, Self> where H: Handler, { - self.on(MethodFilter::GET, handler) + self.on(MethodFilter::GET | MethodFilter::HEAD, handler) } /// Chain an additional handler that will only accept `HEAD` requests. @@ -624,6 +632,8 @@ where } fn call(&mut self, req: Request) -> Self::Future { + let req_method = req.method().clone(); + let f = if self.method.matches(req.method()) { let fut = self.svc.clone().oneshot(req); RouteFuture::a(fut) @@ -632,6 +642,9 @@ where RouteFuture::b(fut) }; - future::OnMethodFuture { inner: f } + future::OnMethodFuture { + inner: f, + req_method, + } } } diff --git a/src/service/future.rs b/src/service/future.rs index b4ac261e..d1c0b89e 100644 --- a/src/service/future.rs +++ b/src/service/future.rs @@ -6,7 +6,8 @@ use crate::{ }; use bytes::Bytes; use futures_util::ready; -use http::{Request, Response}; +use http::{Method, Request, Response}; +use http_body::Empty; use pin_project_lite::pin_project; use std::{ future::Future, @@ -85,6 +86,7 @@ pin_project! { { #[pin] pub(super) inner: crate::routing::future::RouteFuture, + pub(super) req_method: Method, } } @@ -96,6 +98,13 @@ where type Output = Result, S::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project().inner.poll(cx) + let this = self.project(); + let response = futures_util::ready!(this.inner.poll(cx))?; + if this.req_method == &Method::HEAD { + let response = response.map(|_| box_body(Empty::new())); + Poll::Ready(Ok(response)) + } else { + Poll::Ready(Ok(response)) + } } } diff --git a/src/service/mod.rs b/src/service/mod.rs index fd21d33a..ea7bed05 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -152,11 +152,15 @@ where /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` +/// +/// Note that `get` routes will also be called for `HEAD` requests but will have +/// the response body removed. Make sure to add explicit `HEAD` routes +/// afterwards. pub fn get(svc: S) -> OnMethod, EmptyRouter> where S: Service> + Clone, { - on(MethodFilter::GET, svc) + on(MethodFilter::GET | MethodFilter::HEAD, svc) } /// Route `HEAD` requests to the given service. @@ -322,11 +326,15 @@ impl OnMethod { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` + /// + /// Note that `get` routes will also be called for `HEAD` requests but will have + /// the response body removed. Make sure to add explicit `HEAD` routes + /// afterwards. pub fn get(self, svc: T) -> OnMethod, Self> where T: Service> + Clone, { - self.on(MethodFilter::GET, svc) + self.on(MethodFilter::GET | MethodFilter::HEAD, svc) } /// Chain an additional service that will only accept `HEAD` requests. @@ -465,6 +473,8 @@ where } fn call(&mut self, req: Request) -> Self::Future { + let req_method = req.method().clone(); + let f = if self.method.matches(req.method()) { let fut = self.svc.clone().oneshot(req); RouteFuture::a(fut) @@ -473,7 +483,10 @@ where RouteFuture::b(fut) }; - future::OnMethodFuture { inner: f } + future::OnMethodFuture { + inner: f, + req_method, + } } } diff --git a/src/tests/get_to_head.rs b/src/tests/get_to_head.rs new file mode 100644 index 00000000..55c49bb0 --- /dev/null +++ b/src/tests/get_to_head.rs @@ -0,0 +1,73 @@ +use super::*; +use http::Method; +use tower::ServiceExt; + +mod for_handlers { + use super::*; + + #[tokio::test] + async fn get_handles_head() { + let app = route( + "/", + get(|| async { + let mut headers = HeaderMap::new(); + headers.insert("x-some-header", "foobar".parse().unwrap()); + (headers, "you shouldn't see this") + }), + ); + + // don't use reqwest because it always strips bodies from HEAD responses + let res = app + .oneshot( + Request::builder() + .uri("/") + .method(Method::HEAD) + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers()["x-some-header"], "foobar"); + + let body = hyper::body::to_bytes(res.into_body()).await.unwrap(); + assert_eq!(body.len(), 0); + } +} + +mod for_services { + use super::*; + use crate::service::get; + + #[tokio::test] + async fn get_handles_head() { + let app = route( + "/", + get((|| async { + let mut headers = HeaderMap::new(); + headers.insert("x-some-header", "foobar".parse().unwrap()); + (headers, "you shouldn't see this") + }) + .into_service()), + ); + + // don't use reqwest because it always strips bodies from HEAD responses + let res = app + .oneshot( + Request::builder() + .uri("/") + .method(Method::HEAD) + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.headers()["x-some-header"], "foobar"); + + let body = hyper::body::to_bytes(res.into_body()).await.unwrap(); + assert_eq!(body.len(), 0); + } +} diff --git a/src/tests/mod.rs b/src/tests/mod.rs index bfaea8c1..caa4aaad 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -5,7 +5,10 @@ use crate::{ }; use bytes::Bytes; use futures_util::future::Ready; -use http::{header::AUTHORIZATION, Request, Response, StatusCode, Uri}; +use http::{ + header::{HeaderMap, AUTHORIZATION}, + Request, Response, StatusCode, Uri, +}; use hyper::{Body, Server}; use serde::Deserialize; use serde_json::json; @@ -18,6 +21,7 @@ use std::{ }; use tower::{make::Shared, service_fn, BoxError, Service}; +mod get_to_head; mod handle_error; mod nest; mod or;