diff --git a/examples/reqwest-response/Cargo.toml b/examples/reqwest-response/Cargo.toml new file mode 100644 index 00000000..9f47fd63 --- /dev/null +++ b/examples/reqwest-response/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "example-reqwest-response" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum" } +reqwest = { version = "0.11", features = ["stream"] } +tokio = { version = "1.0", features = ["full"] } +tokio-stream = "0.1" +tower-http = { version = "0.4", features = ["trace"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/reqwest-response/src/main.rs b/examples/reqwest-response/src/main.rs new file mode 100644 index 00000000..db703ee5 --- /dev/null +++ b/examples/reqwest-response/src/main.rs @@ -0,0 +1,79 @@ +//! Run with +//! +//! ```not_rust +//! cargo run -p example-reqwest-response +//! ``` + +use std::{convert::Infallible, time::Duration}; + +use axum::{ + body::{Body, Bytes}, + extract::State, + response::{IntoResponse, Response}, + routing::get, + Router, +}; +use reqwest::{Client, StatusCode}; +use tokio_stream::StreamExt; +use tower_http::trace::TraceLayer; +use tracing::Span; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "example_reqwest_response=debug,tower_http=debug".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let client = Client::new(); + + let app = Router::new() + .route("/", get(proxy_via_reqwest)) + .route("/stream", get(stream_some_data)) + // Add some logging so we can see the streams going through + .layer(TraceLayer::new_for_http().on_body_chunk( + |chunk: &Bytes, _latency: Duration, _span: &Span| { + tracing::debug!("streaming {} bytes", chunk.len()); + }, + )) + .with_state(client); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + tracing::debug!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); +} + +async fn proxy_via_reqwest(State(client): State) -> Response { + let reqwest_response = match client.get("http://127.0.0.1:3000/stream").send().await { + Ok(res) => res, + Err(err) => { + tracing::error!(%err, "request failed"); + return StatusCode::BAD_GATEWAY.into_response(); + } + }; + + let mut response_builder = Response::builder().status(reqwest_response.status()); + + // This unwrap is fine because we haven't insert any headers yet so there can't be any invalid + // headers + *response_builder.headers_mut().unwrap() = reqwest_response.headers().clone(); + + response_builder + .body(Body::from_stream(reqwest_response.bytes_stream())) + // Same goes for this unwrap + .unwrap() +} + +async fn stream_some_data() -> Body { + let stream = tokio_stream::iter(0..5) + .throttle(Duration::from_secs(1)) + .map(|n| n.to_string()) + .map(Ok::<_, Infallible>); + Body::from_stream(stream) +}