1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-03-30 11:19:20 +02:00

Add example

This commit is contained in:
David Pedersen 2021-05-30 14:33:20 +02:00
parent f4268471b6
commit 7328127a3d
6 changed files with 202 additions and 58 deletions

View file

@ -19,8 +19,10 @@ thiserror = "1.0"
tower = { version = "0.4", features = ["util"] }
[dev-dependencies]
tokio = { version = "1.6.1", features = ["macros", "rt"] }
tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] }
serde = { version = "1.0", features = ["derive"] }
tower = { version = "0.4", features = ["util", "make", "timeout"] }
tower-http = { version = "0.1", features = ["trace", "compression", "add-extension"] }
hyper = { version = "0.14", features = ["full"] }
tracing = "0.1"
tracing-subscriber = "0.2"

View file

@ -0,0 +1,89 @@
use bytes::Bytes;
use http::{Request, StatusCode};
use hyper::Server;
use serde::Deserialize;
use std::{
collections::HashMap,
net::SocketAddr,
sync::{Arc, Mutex},
time::Duration,
};
use tower::{make::Shared, ServiceBuilder};
use tower_http::{
add_extension::AddExtensionLayer, compression::CompressionLayer, trace::TraceLayer,
};
use tower_web::{body::Body, extract, response, Error};
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
// build our application with some routes
let app = tower_web::app()
.at("/get")
.get(get)
.at("/set")
.post(set)
// convert it into a `Service`
.into_service();
// add some middleware
let app = ServiceBuilder::new()
.timeout(Duration::from_secs(10))
.layer(TraceLayer::new_for_http())
.layer(CompressionLayer::new())
.layer(AddExtensionLayer::new(SharedState::default()))
.service(app);
// run it with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
let server = Server::bind(&addr).serve(Shared::new(app));
server.await.unwrap();
}
type SharedState = Arc<Mutex<State>>;
#[derive(Default)]
struct State {
db: HashMap<String, Bytes>,
}
#[derive(Deserialize)]
struct GetSetQueryString {
key: String,
}
async fn get(
_req: Request<Body>,
query: extract::Query<GetSetQueryString>,
state: extract::Extension<SharedState>,
) -> Result<Bytes, Error> {
let state = state.into_inner();
let db = &state.lock().unwrap().db;
let key = query.into_inner().key;
if let Some(value) = db.get(&key) {
Ok(value.clone())
} else {
Err(Error::WithStatus(StatusCode::NOT_FOUND))
}
}
async fn set(
_req: Request<Body>,
query: extract::Query<GetSetQueryString>,
value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb
state: extract::Extension<SharedState>,
) -> Result<response::Empty, Error> {
let state = state.into_inner();
let db = &mut state.lock().unwrap().db;
let key = query.into_inner().key;
let value = value.into_inner();
db.insert(key, value);
Ok(response::Empty)
}

View file

@ -29,6 +29,15 @@ pub enum Error {
#[error("request extension of type `{type_name}` was not set")]
MissingExtension { type_name: &'static str },
#[error("`Content-Length` header is missing but was required")]
LengthRequired,
#[error("response body was too large")]
PayloadTooLarge,
#[error("response failed with status {0}")]
WithStatus(StatusCode),
}
impl From<Infallible> for Error {
@ -55,6 +64,11 @@ where
| Error::QueryStringMissing
| Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST),
Error::WithStatus(status) => make_response(status),
Error::LengthRequired => make_response(StatusCode::LENGTH_REQUIRED),
Error::PayloadTooLarge => make_response(StatusCode::PAYLOAD_TOO_LARGE),
Error::MissingExtension { .. } | Error::SerializeResponseBody(_) => {
make_response(StatusCode::INTERNAL_SERVER_ERROR)
}

View file

@ -1,6 +1,8 @@
use crate::{body::Body, Error};
use bytes::Bytes;
use futures_util::{future, ready};
use http::Request;
use http_body::Body as _;
use pin_project::pin_project;
use serde::de::DeserializeOwned;
use std::{
@ -128,16 +130,6 @@ where
}
}
// TODO(david): can we add a length limit somehow? Maybe a const generic?
#[derive(Debug, Clone)]
pub struct Bytes(bytes::Bytes);
impl Bytes {
pub fn into_inner(self) -> bytes::Bytes {
self.0
}
}
impl FromRequest for Bytes {
type Future = future::BoxFuture<'static, Result<Self, Error>>;
@ -148,7 +140,44 @@ impl FromRequest for Bytes {
let bytes = hyper::body::to_bytes(body)
.await
.map_err(Error::ConsumeRequestBody)?;
Ok(Bytes(bytes))
Ok(bytes)
})
}
}
#[derive(Debug, Clone)]
pub struct BytesMaxLength<const N: u64>(Bytes);
impl<const N: u64> BytesMaxLength<N> {
pub fn into_inner(self) -> Bytes {
self.0
}
}
impl<const N: u64> FromRequest for BytesMaxLength<N> {
type Future = future::BoxFuture<'static, Result<Self, Error>>;
fn from_request(req: &mut Request<Body>) -> Self::Future {
let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned();
let body = std::mem::take(req.body_mut());
Box::pin(async move {
let content_length =
content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok());
if let Some(length) = content_length {
if length > N {
return Err(Error::PayloadTooLarge);
}
} else {
return Err(Error::LengthRequired);
};
let bytes = hyper::body::to_bytes(body)
.await
.map_err(Error::ConsumeRequestBody)?;
Ok(BytesMaxLength(bytes))
})
}
}

View file

@ -6,6 +6,10 @@ Improvements to make:
Support extracting headers, perhaps via `headers::Header`?
Actual routing
Improve compile times with lots of routes, can we box and combine routers?
Tests
*/
@ -181,7 +185,10 @@ mod tests {
Ok(Response::new(Body::from("Hello, World!")))
}
async fn large_static_file(_: Request<Body>) -> Result<Response<Body>, Error> {
async fn large_static_file(
_: Request<Body>,
body: extract::BytesMaxLength<{ 1024 * 500 }>,
) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::empty()))
}
@ -309,24 +316,4 @@ mod tests {
let bytes = hyper::body::to_bytes(res.into_body()).await.unwrap();
String::from_utf8(bytes.to_vec()).unwrap()
}
#[allow(dead_code)]
// this should just compile
async fn compatible_with_hyper_and_tower_http() {
let app = app()
.at("/")
.get(|_: Request<Body>| async {
Ok::<_, Error>(Response::new(Body::from("Hello, World!")))
})
.into_service();
let app = ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(CompressionLayer::new())
.service(app);
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
let server = Server::bind(&addr).serve(Shared::new(app));
server.await.unwrap();
}
}

View file

@ -25,33 +25,53 @@ impl IntoResponse<Body> for String {
}
}
impl IntoResponse<Body> for Bytes {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
impl IntoResponse<Body> for &'static [u8] {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
impl IntoResponse<Body> for Vec<u8> {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
impl IntoResponse<Body> for std::borrow::Cow<'static, str> {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
impl IntoResponse<Body> for Bytes {
fn into_response(self) -> Result<Response<Body>, Error> {
let mut res = Response::new(Body::from(self));
res.headers_mut().insert(
header::CONTENT_TYPE,
HeaderValue::from_static("application/octet-stream"),
);
Ok(res)
}
}
impl IntoResponse<Body> for &'static [u8] {
fn into_response(self) -> Result<Response<Body>, Error> {
let mut res = Response::new(Body::from(self));
res.headers_mut().insert(
header::CONTENT_TYPE,
HeaderValue::from_static("application/octet-stream"),
);
Ok(res)
}
}
impl IntoResponse<Body> for Vec<u8> {
fn into_response(self) -> Result<Response<Body>, Error> {
let mut res = Response::new(Body::from(self));
res.headers_mut().insert(
header::CONTENT_TYPE,
HeaderValue::from_static("application/octet-stream"),
);
Ok(res)
}
}
impl IntoResponse<Body> for std::borrow::Cow<'static, [u8]> {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
let mut res = Response::new(Body::from(self));
res.headers_mut().insert(
header::CONTENT_TYPE,
HeaderValue::from_static("application/octet-stream"),
);
Ok(res)
}
}
@ -63,17 +83,20 @@ where
{
fn into_response(self) -> Result<Response<Body>, Error> {
let bytes = serde_json::to_vec(&self.0).map_err(Error::SerializeResponseBody)?;
let len = bytes.len();
let mut res = Response::new(Body::from(bytes));
res.headers_mut().insert(
header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);
res.headers_mut()
.insert(header::CONTENT_LENGTH, HeaderValue::from(len));
Ok(res)
}
}
#[derive(Debug, Copy, Clone)]
pub struct Empty;
impl IntoResponse<Body> for Empty {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::empty()))
}
}