diff --git a/examples/compression/Cargo.toml b/examples/compression/Cargo.toml new file mode 100644 index 00000000..d5fdcf8e --- /dev/null +++ b/examples/compression/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "example-compression" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum" } +axum-extra = { path = "../../axum-extra", features = ["typed-header"] } +serde_json = "1" +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } +tower = "0.4" +tower-http = { version = "0.5", features = ["compression-full", "decompression-full"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +[dev-dependencies] +assert-json-diff = "2.0" +brotli = "3.4" +flate2 = "1" +http = "1" +zstd = "0.13" diff --git a/examples/compression/README.md b/examples/compression/README.md new file mode 100644 index 00000000..3f0ed94d --- /dev/null +++ b/examples/compression/README.md @@ -0,0 +1,32 @@ +# compression + +This example shows how to: +- automatically decompress request bodies when necessary +- compress response bodies based on the `Accept-Encoding` header. + +## Running + +``` +cargo run -p example-compression +``` + +## Sending compressed requests + +``` +curl -v -g 'http://localhost:3000/' \ + -H "Content-Type: application/json" \ + -H "Content-Encoding: gzip" \ + --compressed \ + --data-binary @data/products.json.gz +``` + +(Notice the `Content-Encoding: gzip` in the request, and `content-encoding: gzip` in the response.) 
+ +## Sending uncompressed requests + +``` +curl -v -g 'http://localhost:3000/' \ + -H "Content-Type: application/json" \ + --compressed \ + --data-binary @data/products.json +``` diff --git a/examples/compression/data/products.json b/examples/compression/data/products.json new file mode 100644 index 00000000..a234fbdd --- /dev/null +++ b/examples/compression/data/products.json @@ -0,0 +1,12 @@ +{ + "products": [ + { + "id": 1, + "name": "Product 1" + }, + { + "id": 2, + "name": "Product 2" + } + ] +} diff --git a/examples/compression/data/products.json.gz b/examples/compression/data/products.json.gz new file mode 100644 index 00000000..91d39895 Binary files /dev/null and b/examples/compression/data/products.json.gz differ diff --git a/examples/compression/src/main.rs b/examples/compression/src/main.rs new file mode 100644 index 00000000..1fa1bb49 --- /dev/null +++ b/examples/compression/src/main.rs @@ -0,0 +1,39 @@ +use axum::{routing::post, Json, Router}; +use serde_json::Value; +use tower::ServiceBuilder; +use tower_http::{compression::CompressionLayer, decompression::RequestDecompressionLayer}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[cfg(test)] +mod tests; + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "example-compression=trace".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let app: Router = app(); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + tracing::debug!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); +} + +fn app() -> Router { + Router::new().route("/", post(root)).layer( + ServiceBuilder::new() + .layer(RequestDecompressionLayer::new()) + .layer(CompressionLayer::new()), + ) +} + +async fn root(Json(value): Json) -> Json { + Json(value) +} diff --git a/examples/compression/src/tests.rs 
b/examples/compression/src/tests.rs new file mode 100644 index 00000000..c91ccaa6 --- /dev/null +++ b/examples/compression/src/tests.rs @@ -0,0 +1,245 @@ +use assert_json_diff::assert_json_eq; +use axum::{ + body::{Body, Bytes}, + response::Response, +}; +use brotli::enc::BrotliEncoderParams; +use flate2::{read::GzDecoder, write::GzEncoder, Compression}; +use http::{header, StatusCode}; +use serde_json::{json, Value}; +use std::io::{Read, Write}; +use tower::ServiceExt; + +use super::*; + +#[tokio::test] +async fn handle_uncompressed_request_bodies() { + // Given + + let body = json(); + + let compressed_request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .body(json_body(&body)) + .unwrap(); + + // When + + let response = app().oneshot(compressed_request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + assert_json_eq!(json_from_response(response).await, json()); +} + +#[tokio::test] +async fn decompress_gzip_request_bodies() { + // Given + + let body = compress_gzip(&json()); + + let compressed_request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::CONTENT_ENCODING, "gzip") + .body(Body::from(body)) + .unwrap(); + + // When + + let response = app().oneshot(compressed_request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + assert_json_eq!(json_from_response(response).await, json()); +} + +#[tokio::test] +async fn decompress_br_request_bodies() { + // Given + + let body = compress_br(&json()); + + let compressed_request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::CONTENT_ENCODING, "br") + .body(Body::from(body)) + .unwrap(); + + // When + + let response = app().oneshot(compressed_request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + assert_json_eq!(json_from_response(response).await, json()); +} + +#[tokio::test] +async fn 
decompress_zstd_request_bodies() { + // Given + + let body = compress_zstd(&json()); + + let compressed_request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::CONTENT_ENCODING, "zstd") + .body(Body::from(body)) + .unwrap(); + + // When + + let response = app().oneshot(compressed_request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + assert_json_eq!(json_from_response(response).await, json()); +} + +#[tokio::test] +async fn do_not_compress_response_bodies() { + // Given + let request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .body(json_body(&json())) + .unwrap(); + + // When + + let response = app().oneshot(request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + assert_json_eq!(json_from_response(response).await, json()); +} + +#[tokio::test] +async fn compress_response_bodies_with_gzip() { + // Given + let request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::ACCEPT_ENCODING, "gzip") + .body(json_body(&json())) + .unwrap(); + + // When + + let response = app().oneshot(request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + let response_body = byte_from_response(response).await; + let mut decoder = GzDecoder::new(response_body.as_ref()); + let mut decompress_body = String::new(); + decoder.read_to_string(&mut decompress_body).unwrap(); + assert_json_eq!( + serde_json::from_str::(&decompress_body).unwrap(), + json() + ); +} + +#[tokio::test] +async fn compress_response_bodies_with_br() { + // Given + let request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::ACCEPT_ENCODING, "br") + .body(json_body(&json())) + .unwrap(); + + // When + + let response = app().oneshot(request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + let response_body = 
byte_from_response(response).await; + let mut decompress_body = Vec::new(); + brotli::BrotliDecompress(&mut response_body.as_ref(), &mut decompress_body).unwrap(); + assert_json_eq!( + serde_json::from_slice::(&decompress_body).unwrap(), + json() + ); +} + +#[tokio::test] +async fn compress_response_bodies_with_zstd() { + // Given + let request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::ACCEPT_ENCODING, "zstd") + .body(json_body(&json())) + .unwrap(); + + // When + + let response = app().oneshot(request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + let response_body = byte_from_response(response).await; + let decompress_body = zstd::stream::decode_all(std::io::Cursor::new(response_body)).unwrap(); + assert_json_eq!( + serde_json::from_slice::(&decompress_body).unwrap(), + json() + ); +} + +fn json() -> Value { + json!({ + "name": "foo", + "mainProduct": { + "typeId": "product", + "id": "p1" + }, + }) +} + +fn json_body(input: &Value) -> Body { + Body::from(serde_json::to_vec(&input).unwrap()) +} + +async fn json_from_response(response: Response) -> Value { + let body = byte_from_response(response).await; + body_as_json(body) +} + +async fn byte_from_response(response: Response) -> Bytes { + axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap() +} + +fn body_as_json(body: Bytes) -> Value { + serde_json::from_slice(body.as_ref()).unwrap() +} + +fn compress_gzip(json: &Value) -> Vec { + let request_body = serde_json::to_vec(&json).unwrap(); + + let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); + encoder.write_all(&request_body).unwrap(); + encoder.finish().unwrap() +} + +fn compress_br(json: &Value) -> Vec { + let request_body = serde_json::to_vec(&json).unwrap(); + let mut result = Vec::new(); + + let params = BrotliEncoderParams::default(); + let _ = brotli::enc::BrotliCompress(&mut &request_body[..], &mut result, ¶ms).unwrap(); + + 
result +} + +fn compress_zstd(json: &Value) -> Vec { + let request_body = serde_json::to_vec(&json).unwrap(); + zstd::stream::encode_all(std::io::Cursor::new(request_body), 4).unwrap() +}