use bytes::Bytes; use http::{Request, StatusCode}; use hyper::Server; use std::{ collections::HashMap, net::SocketAddr, sync::{Arc, Mutex}, time::Duration, }; use tower::{make::Shared, ServiceBuilder}; use tower_http::{ add_extension::AddExtensionLayer, compression::CompressionLayer, trace::TraceLayer, }; use tower_web::{ body::Body, extract::{BytesMaxLength, Extension, UrlParams}, handler::Handler, }; #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); // build our application with some routes let app = tower_web::app() .at("/:key") .get(get.layer(CompressionLayer::new())) .post(set) // convert it into a `Service` .into_service(); // add some middleware let app = ServiceBuilder::new() .timeout(Duration::from_secs(10)) .layer(TraceLayer::new_for_http()) .layer(AddExtensionLayer::new(SharedState::default())) .service(app); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); let server = Server::bind(&addr).serve(Shared::new(app)); server.await.unwrap(); } type SharedState = Arc>; #[derive(Default)] struct State { db: HashMap, } async fn get( _req: Request, UrlParams((key,)): UrlParams<(String,)>, Extension(state): Extension, ) -> Result { let db = &state.lock().unwrap().db; if let Some(value) = db.get(&key) { Ok(value.clone()) } else { Err(StatusCode::NOT_FOUND) } } async fn set( _req: Request, UrlParams((key,)): UrlParams<(String,)>, BytesMaxLength(value): BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb Extension(state): Extension, ) { state.lock().unwrap().db.insert(key, value); }