mirror of
https://github.com/tokio-rs/axum.git
synced 2025-03-30 11:19:20 +02:00
Add example
This commit is contained in:
parent
f4268471b6
commit
7328127a3d
6 changed files with 202 additions and 58 deletions
|
@ -19,8 +19,10 @@ thiserror = "1.0"
|
|||
tower = { version = "0.4", features = ["util"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { version = "1.6.1", features = ["macros", "rt"] }
|
||||
tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tower = { version = "0.4", features = ["util", "make", "timeout"] }
|
||||
tower-http = { version = "0.1", features = ["trace", "compression", "add-extension"] }
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = "0.2"
|
||||
|
|
89
examples/key_value_store.rs
Normal file
89
examples/key_value_store.rs
Normal file
|
@ -0,0 +1,89 @@
|
|||
use bytes::Bytes;
|
||||
use http::{Request, StatusCode};
|
||||
use hyper::Server;
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
net::SocketAddr,
|
||||
sync::{Arc, Mutex},
|
||||
time::Duration,
|
||||
};
|
||||
use tower::{make::Shared, ServiceBuilder};
|
||||
use tower_http::{
|
||||
add_extension::AddExtensionLayer, compression::CompressionLayer, trace::TraceLayer,
|
||||
};
|
||||
use tower_web::{body::Body, extract, response, Error};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
// build our application with some routes
|
||||
let app = tower_web::app()
|
||||
.at("/get")
|
||||
.get(get)
|
||||
.at("/set")
|
||||
.post(set)
|
||||
// convert it into a `Service`
|
||||
.into_service();
|
||||
|
||||
// add some middleware
|
||||
let app = ServiceBuilder::new()
|
||||
.timeout(Duration::from_secs(10))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(CompressionLayer::new())
|
||||
.layer(AddExtensionLayer::new(SharedState::default()))
|
||||
.service(app);
|
||||
|
||||
// run it with hyper
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
tracing::debug!("listening on {}", addr);
|
||||
let server = Server::bind(&addr).serve(Shared::new(app));
|
||||
server.await.unwrap();
|
||||
}
|
||||
|
||||
type SharedState = Arc<Mutex<State>>;
|
||||
|
||||
#[derive(Default)]
|
||||
struct State {
|
||||
db: HashMap<String, Bytes>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct GetSetQueryString {
|
||||
key: String,
|
||||
}
|
||||
|
||||
async fn get(
|
||||
_req: Request<Body>,
|
||||
query: extract::Query<GetSetQueryString>,
|
||||
state: extract::Extension<SharedState>,
|
||||
) -> Result<Bytes, Error> {
|
||||
let state = state.into_inner();
|
||||
let db = &state.lock().unwrap().db;
|
||||
|
||||
let key = query.into_inner().key;
|
||||
|
||||
if let Some(value) = db.get(&key) {
|
||||
Ok(value.clone())
|
||||
} else {
|
||||
Err(Error::WithStatus(StatusCode::NOT_FOUND))
|
||||
}
|
||||
}
|
||||
|
||||
async fn set(
|
||||
_req: Request<Body>,
|
||||
query: extract::Query<GetSetQueryString>,
|
||||
value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb
|
||||
state: extract::Extension<SharedState>,
|
||||
) -> Result<response::Empty, Error> {
|
||||
let state = state.into_inner();
|
||||
let db = &mut state.lock().unwrap().db;
|
||||
|
||||
let key = query.into_inner().key;
|
||||
let value = value.into_inner();
|
||||
|
||||
db.insert(key, value);
|
||||
|
||||
Ok(response::Empty)
|
||||
}
|
14
src/error.rs
14
src/error.rs
|
@ -29,6 +29,15 @@ pub enum Error {
|
|||
|
||||
#[error("request extension of type `{type_name}` was not set")]
|
||||
MissingExtension { type_name: &'static str },
|
||||
|
||||
#[error("`Content-Length` header is missing but was required")]
|
||||
LengthRequired,
|
||||
|
||||
#[error("response body was too large")]
|
||||
PayloadTooLarge,
|
||||
|
||||
#[error("response failed with status {0}")]
|
||||
WithStatus(StatusCode),
|
||||
}
|
||||
|
||||
impl From<Infallible> for Error {
|
||||
|
@ -55,6 +64,11 @@ where
|
|||
| Error::QueryStringMissing
|
||||
| Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST),
|
||||
|
||||
Error::WithStatus(status) => make_response(status),
|
||||
|
||||
Error::LengthRequired => make_response(StatusCode::LENGTH_REQUIRED),
|
||||
Error::PayloadTooLarge => make_response(StatusCode::PAYLOAD_TOO_LARGE),
|
||||
|
||||
Error::MissingExtension { .. } | Error::SerializeResponseBody(_) => {
|
||||
make_response(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
use crate::{body::Body, Error};
|
||||
use bytes::Bytes;
|
||||
use futures_util::{future, ready};
|
||||
use http::Request;
|
||||
use http_body::Body as _;
|
||||
use pin_project::pin_project;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::{
|
||||
|
@ -128,16 +130,6 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
// TODO(david): can we add a length limit somehow? Maybe a const generic?
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Bytes(bytes::Bytes);
|
||||
|
||||
impl Bytes {
|
||||
pub fn into_inner(self) -> bytes::Bytes {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRequest for Bytes {
|
||||
type Future = future::BoxFuture<'static, Result<Self, Error>>;
|
||||
|
||||
|
@ -148,7 +140,44 @@ impl FromRequest for Bytes {
|
|||
let bytes = hyper::body::to_bytes(body)
|
||||
.await
|
||||
.map_err(Error::ConsumeRequestBody)?;
|
||||
Ok(Bytes(bytes))
|
||||
Ok(bytes)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BytesMaxLength<const N: u64>(Bytes);
|
||||
|
||||
impl<const N: u64> BytesMaxLength<N> {
|
||||
pub fn into_inner(self) -> Bytes {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: u64> FromRequest for BytesMaxLength<N> {
|
||||
type Future = future::BoxFuture<'static, Result<Self, Error>>;
|
||||
|
||||
fn from_request(req: &mut Request<Body>) -> Self::Future {
|
||||
let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned();
|
||||
let body = std::mem::take(req.body_mut());
|
||||
|
||||
Box::pin(async move {
|
||||
let content_length =
|
||||
content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok());
|
||||
|
||||
if let Some(length) = content_length {
|
||||
if length > N {
|
||||
return Err(Error::PayloadTooLarge);
|
||||
}
|
||||
} else {
|
||||
return Err(Error::LengthRequired);
|
||||
};
|
||||
|
||||
let bytes = hyper::body::to_bytes(body)
|
||||
.await
|
||||
.map_err(Error::ConsumeRequestBody)?;
|
||||
|
||||
Ok(BytesMaxLength(bytes))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
29
src/lib.rs
29
src/lib.rs
|
@ -6,6 +6,10 @@ Improvements to make:
|
|||
|
||||
Support extracting headers, perhaps via `headers::Header`?
|
||||
|
||||
Actual routing
|
||||
|
||||
Improve compile times with lots of routes, can we box and combine routers?
|
||||
|
||||
Tests
|
||||
|
||||
*/
|
||||
|
@ -181,7 +185,10 @@ mod tests {
|
|||
Ok(Response::new(Body::from("Hello, World!")))
|
||||
}
|
||||
|
||||
async fn large_static_file(_: Request<Body>) -> Result<Response<Body>, Error> {
|
||||
async fn large_static_file(
|
||||
_: Request<Body>,
|
||||
body: extract::BytesMaxLength<{ 1024 * 500 }>,
|
||||
) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::empty()))
|
||||
}
|
||||
|
||||
|
@ -309,24 +316,4 @@ mod tests {
|
|||
let bytes = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||
String::from_utf8(bytes.to_vec()).unwrap()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
// this should just compile
|
||||
async fn compatible_with_hyper_and_tower_http() {
|
||||
let app = app()
|
||||
.at("/")
|
||||
.get(|_: Request<Body>| async {
|
||||
Ok::<_, Error>(Response::new(Body::from("Hello, World!")))
|
||||
})
|
||||
.into_service();
|
||||
|
||||
let app = ServiceBuilder::new()
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(CompressionLayer::new())
|
||||
.service(app);
|
||||
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
let server = Server::bind(&addr).serve(Shared::new(app));
|
||||
server.await.unwrap();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,33 +25,53 @@ impl IntoResponse<Body> for String {
|
|||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for Bytes {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for &'static [u8] {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for Vec<u8> {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for std::borrow::Cow<'static, str> {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for Bytes {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
let mut res = Response::new(Body::from(self));
|
||||
res.headers_mut().insert(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static("application/octet-stream"),
|
||||
);
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for &'static [u8] {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
let mut res = Response::new(Body::from(self));
|
||||
res.headers_mut().insert(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static("application/octet-stream"),
|
||||
);
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for Vec<u8> {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
let mut res = Response::new(Body::from(self));
|
||||
res.headers_mut().insert(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static("application/octet-stream"),
|
||||
);
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse<Body> for std::borrow::Cow<'static, [u8]> {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::from(self)))
|
||||
let mut res = Response::new(Body::from(self));
|
||||
res.headers_mut().insert(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static("application/octet-stream"),
|
||||
);
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -63,17 +83,20 @@ where
|
|||
{
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
let bytes = serde_json::to_vec(&self.0).map_err(Error::SerializeResponseBody)?;
|
||||
let len = bytes.len();
|
||||
let mut res = Response::new(Body::from(bytes));
|
||||
|
||||
res.headers_mut().insert(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
res.headers_mut()
|
||||
.insert(header::CONTENT_LENGTH, HeaderValue::from(len));
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct Empty;
|
||||
|
||||
impl IntoResponse<Body> for Empty {
|
||||
fn into_response(self) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::empty()))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue