diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index aad94047..50edeb42 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -28,6 +28,7 @@ json-lines = [ "tokio-util?/io", "tokio-stream?/io-util" ] +protobuf = ["dep:prost"] query = ["dep:serde", "dep:serde_html_form"] spa = ["tower-http/fs"] typed-routing = ["dep:axum-macros", "dep:serde", "dep:percent-encoding"] @@ -49,6 +50,7 @@ tower-service = "0.3" axum-macros = { path = "../axum-macros", version = "0.2.2", optional = true } cookie = { package = "cookie", version = "0.16", features = ["percent-encode"], optional = true } percent-encoding = { version = "2.1", optional = true } +prost = { version = "0.11", optional = true } serde = { version = "1.0", optional = true } serde_html_form = { version = "0.1", optional = true } serde_json = { version = "1.0.71", optional = true } diff --git a/axum-extra/src/lib.rs b/axum-extra/src/lib.rs index 6f7ac4e2..f057f468 100644 --- a/axum-extra/src/lib.rs +++ b/axum-extra/src/lib.rs @@ -78,6 +78,9 @@ pub mod routing; #[cfg(feature = "json-lines")] pub mod json_lines; +#[cfg(feature = "protobuf")] +pub mod protobuf; + /// Combines two extractors or responses into a single type. #[derive(Debug, Copy, Clone)] pub enum Either { diff --git a/axum-extra/src/protobuf.rs b/axum-extra/src/protobuf.rs new file mode 100644 index 00000000..823c98e3 --- /dev/null +++ b/axum-extra/src/protobuf.rs @@ -0,0 +1,319 @@ +//! Protocol Buffer extractor and response. + +use axum::{ + async_trait, + body::{Bytes, HttpBody}, + extract::{rejection::BytesRejection, FromRequest, RequestParts}, + response::{IntoResponse, Response}, + BoxError, +}; +use bytes::BytesMut; +use http::StatusCode; +use prost::Message; +use std::ops::{Deref, DerefMut}; + +/// A Protocol Buffer message extractor and response. +/// +/// This can be used both as an extractor and as a response. +/// +/// # As extractor +/// +/// When used as an extractor, it can decode request bodies into some type that +/// implements [`prost::Message`]. The request will be rejected (and a [`ProtoBufRejection`] will +/// be returned) if: +/// +/// - The body couldn't be decoded into the target Protocol Buffer message type. +/// - Buffering the request body fails. +/// +/// See [`ProtoBufRejection`] for more details. +/// +/// The extractor does not expect a `Content-Type` header to be present in the request. +/// +/// # Extractor example +/// +/// ```rust,no_run +/// use axum::{routing::post, Router}; +/// use axum_extra::protobuf::ProtoBuf; +/// +/// #[derive(prost::Message)] +/// struct CreateUser { +/// #[prost(string, tag="1")] +/// email: String, +/// #[prost(string, tag="2")] +/// password: String, +/// } +/// +/// async fn create_user(ProtoBuf(payload): ProtoBuf) { +/// // payload is `CreateUser` +/// } +/// +/// let app = Router::new().route("/users", post(create_user)); +/// # async { +/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// # As response +/// +/// When used as a response, it can encode any type that implements [`prost::Message`] to +/// a newly allocated buffer. +/// +/// If no `Content-Type` header is set, the `Content-Type: application/octet-stream` header +/// will be used automatically. +/// +/// # Response example +/// +/// ``` +/// use axum::{ +/// extract::Path, +/// routing::get, +/// Router, +/// }; +/// use axum_extra::protobuf::ProtoBuf; +/// +/// #[derive(prost::Message)] +/// struct User { +/// #[prost(string, tag="1")] +/// username: String, +/// } +/// +/// async fn get_user(Path(user_id) : Path) -> ProtoBuf { +/// let user = find_user(user_id).await; +/// ProtoBuf(user) +/// } +/// +/// async fn find_user(user_id: String) -> User { +/// // ... +/// # unimplemented!() +/// } +/// +/// let app = Router::new().route("/users/:id", get(get_user)); +/// # async { +/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +#[derive(Debug, Clone, Copy, Default)] +#[cfg_attr(docsrs, doc(cfg(feature = "protobuf")))] +pub struct ProtoBuf(pub T); + +#[async_trait] +impl FromRequest for ProtoBuf +where + T: Message + Default, + B: HttpBody + Send, + B::Data: Send, + B::Error: Into, +{ + type Rejection = ProtoBufRejection; + + async fn from_request(req: &mut RequestParts) -> Result { + let mut bytes = Bytes::from_request(req).await?; + + match T::decode(&mut bytes) { + Ok(value) => Ok(ProtoBuf(value)), + Err(err) => Err(ProtoBufDecodeError::from_err(err).into()), + } + } +} + +impl Deref for ProtoBuf { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for ProtoBuf { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl From for ProtoBuf { + fn from(inner: T) -> Self { + Self(inner) + } +} + +impl IntoResponse for ProtoBuf +where + T: Message + Default, +{ + fn into_response(self) -> Response { + let mut buf = BytesMut::with_capacity(128); + match &self.0.encode(&mut buf) { + Ok(()) => buf.into_response(), + Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(), + } + } +} + +/// Rejection type for [`ProtoBuf`]. +/// +/// This rejection is used if the request body couldn't be decoded into the target type. +#[derive(Debug)] +pub struct ProtoBufDecodeError(pub(crate) axum::Error); + +impl ProtoBufDecodeError { + pub(crate) fn from_err(err: E) -> Self + where + E: Into, + { + Self(axum::Error::new(err)) + } +} + +impl std::fmt::Display for ProtoBufDecodeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Failed to decode the body: {:?}", self.0) + } +} + +impl std::error::Error for ProtoBufDecodeError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&self.0) + } +} + +impl IntoResponse for ProtoBufDecodeError { + fn into_response(self) -> Response { + StatusCode::UNPROCESSABLE_ENTITY.into_response() + } +} + +/// Rejection used for [`ProtoBuf`]. +/// +/// Contains one variant for each way the [`ProtoBuf`] extractor +/// can fail. +#[derive(Debug)] +pub enum ProtoBufRejection { + #[allow(missing_docs)] + ProtoBufDecodeError(ProtoBufDecodeError), + #[allow(missing_docs)] + BytesRejection(BytesRejection), +} + +impl From for ProtoBufRejection { + fn from(inner: ProtoBufDecodeError) -> Self { + Self::ProtoBufDecodeError(inner) + } +} + +impl From for ProtoBufRejection { + fn from(inner: BytesRejection) -> Self { + Self::BytesRejection(inner) + } +} + +impl IntoResponse for ProtoBufRejection { + fn into_response(self) -> Response { + match self { + Self::ProtoBufDecodeError(inner) => inner.into_response(), + Self::BytesRejection(inner) => inner.into_response(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_helpers::*; + use axum::{routing::post, Router}; + use http::StatusCode; + + #[tokio::test] + async fn decode_body() { + #[derive(prost::Message)] + struct Input { + #[prost(string, tag = "1")] + foo: String, + } + + let app = Router::new().route( + "/", + post(|input: ProtoBuf| async move { input.foo.to_owned() }), + ); + + let input = Input { + foo: "bar".to_string(), + }; + + let client = TestClient::new(app); + let res = client.post("/").body(input.encode_to_vec()).send().await; + + let body = res.text().await; + + assert_eq!(body, "bar"); + } + + #[tokio::test] + async fn prost_decode_error() { + #[derive(prost::Message)] + struct Input { + #[prost(string, tag = "1")] + foo: String, + } + + #[derive(prost::Message)] + struct Expected { + #[prost(int32, tag = "1")] + test: i32, + } + + let app = Router::new().route("/", post(|_: ProtoBuf| async {})); + + let input = Input { + foo: "bar".to_string(), + }; + + let client = TestClient::new(app); + let res = client.post("/").body(input.encode_to_vec()).send().await; + + assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY); + } + + #[tokio::test] + async fn encode_body() { + #[derive(prost::Message)] + struct Input { + #[prost(string, tag = "1")] + foo: String, + } + + #[derive(prost::Message)] + struct Output { + #[prost(string, tag = "1")] + result: String, + } + + let app = Router::new().route( + "/", + post(|input: ProtoBuf| async move { + let output = Output { + result: input.foo.to_owned(), + }; + + ProtoBuf(output) + }), + ); + + let input = Input { + foo: "bar".to_string(), + }; + + let client = TestClient::new(app); + let res = client.post("/").body(input.encode_to_vec()).send().await; + + assert_eq!( + res.headers()["content-type"], + mime::APPLICATION_OCTET_STREAM.as_ref() + ); + + let body = res.bytes().await; + + let output = Output::decode(body).unwrap(); + + assert_eq!(output.result, "bar"); + } +} diff --git a/axum/src/test_helpers/test_client.rs b/axum/src/test_helpers/test_client.rs index d52c61a6..f54b788c 100644 --- a/axum/src/test_helpers/test_client.rs +++ b/axum/src/test_helpers/test_client.rs @@ -121,6 +121,11 @@ pub(crate) struct TestResponse { } impl TestResponse { + #[allow(dead_code)] + pub(crate) async fn bytes(self) -> Bytes { + self.response.bytes().await.unwrap() + } + pub(crate) async fn text(self) -> String { self.response.text().await.unwrap() }