1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-03-31 11:49:55 +02:00

Start writing more tests

This commit is contained in:
David Pedersen 2021-05-31 12:22:16 +02:00
parent 6822766165
commit 593c901aab
5 changed files with 259 additions and 196 deletions

View file

@ -19,10 +19,11 @@ thiserror = "1.0"
tower = { version = "0.4", features = ["util"] }
[dev-dependencies]
tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] }
hyper = { version = "0.14", features = ["full"] }
reqwest = { version = "0.11", features = ["json", "stream"] }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] }
tower = { version = "0.4", features = ["util", "make", "timeout"] }
tower-http = { version = "0.1", features = ["trace", "compression", "add-extension"] }
hyper = { version = "0.14", features = ["full"] }
tracing = "0.1"
tracing-subscriber = "0.2"

View file

@ -44,6 +44,9 @@ pub enum Error {
#[error("unknown URL param `{0}`")]
UnknownUrlParam(String),
#[error("response body didn't contain valid UTF-8")]
InvalidUtf8,
}
impl From<Infallible> for Error {
@ -69,7 +72,8 @@ where
Error::DeserializeRequestBody(_)
| Error::QueryStringMissing
| Error::DeserializeQueryString(_)
| Error::InvalidUrlParam { .. } => make_response(StatusCode::BAD_REQUEST),
| Error::InvalidUrlParam { .. }
| Error::InvalidUtf8 => make_response(StatusCode::BAD_REQUEST),
Error::Status(status) => make_response(status),

View file

@ -1,7 +1,7 @@
use crate::{body::Body, Error};
use bytes::Bytes;
use futures_util::{future, ready};
use http::Request;
use http::{header, Request, StatusCode};
use pin_project::pin_project;
use serde::de::DeserializeOwned;
use std::{
@ -86,20 +86,39 @@ where
type Future = future::BoxFuture<'static, Result<Self, Error>>;
fn from_request(req: &mut Request<Body>) -> Self::Future {
// TODO(david): require the body to have `content-type: application/json`
if has_content_type(&req, "application/json") {
let body = std::mem::take(req.body_mut());
let body = std::mem::take(req.body_mut());
Box::pin(async move {
let bytes = hyper::body::to_bytes(body)
.await
.map_err(Error::ConsumeRequestBody)?;
let value = serde_json::from_slice(&bytes).map_err(Error::DeserializeRequestBody)?;
Ok(Json(value))
})
Box::pin(async move {
let bytes = hyper::body::to_bytes(body)
.await
.map_err(Error::ConsumeRequestBody)?;
let value =
serde_json::from_slice(&bytes).map_err(Error::DeserializeRequestBody)?;
Ok(Json(value))
})
} else {
Box::pin(async { Err(Error::Status(StatusCode::BAD_REQUEST)) })
}
}
}
fn has_content_type<B>(req: &Request<B>, expected_content_type: &str) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
content_type
} else {
return false;
};
let content_type = if let Ok(content_type) = content_type.to_str() {
content_type
} else {
return false;
};
content_type.starts_with(expected_content_type)
}
#[derive(Debug, Clone, Copy)]
pub struct Extension<T>(T);
@ -146,6 +165,32 @@ impl FromRequest for Bytes {
}
}
impl FromRequest for String {
type Future = future::BoxFuture<'static, Result<Self, Error>>;
fn from_request(req: &mut Request<Body>) -> Self::Future {
let body = std::mem::take(req.body_mut());
Box::pin(async move {
let bytes = hyper::body::to_bytes(body)
.await
.map_err(Error::ConsumeRequestBody)?
.to_vec();
let string = String::from_utf8(bytes).map_err(|_| Error::InvalidUtf8)?;
Ok(string)
})
}
}
impl FromRequest for Body {
type Future = future::Ready<Result<Self, Error>>;
fn from_request(req: &mut Request<Body>) -> Self::Future {
let body = std::mem::take(req.body_mut());
future::ok(body)
}
}
#[derive(Debug, Clone)]
pub struct BytesMaxLength<const N: u64>(Bytes);

View file

@ -1,15 +1,3 @@
/*
Improvements to make:
Support extracting headers, perhaps via `headers::Header`?
Improve compile times with lots of routes, can we box and combine routers?
Tests
*/
use self::{
body::Body,
routing::{EmptyRouter, RouteAt},
@ -33,6 +21,9 @@ pub mod routing;
mod error;
#[cfg(test)]
mod tests;
pub use self::error::Error;
pub fn app() -> App<EmptyRouter> {
@ -139,173 +130,3 @@ where
}
}
}
#[cfg(test)]
mod tests {
#![allow(warnings)]
use super::*;
use crate::handler::Handler;
use http::{Method, Request, StatusCode};
use hyper::Server;
use serde::Deserialize;
use std::time::Duration;
use std::{fmt, net::SocketAddr, sync::Arc};
use tower::{
layer::util::Identity, make::Shared, service_fn, timeout::TimeoutLayer, ServiceBuilder,
ServiceExt,
};
use tower_http::{
add_extension::AddExtensionLayer,
compression::CompressionLayer,
trace::{Trace, TraceLayer},
};
#[tokio::test]
async fn basic() {
#[derive(Debug, Deserialize)]
struct Pagination {
page: usize,
per_page: usize,
}
#[derive(Debug, Deserialize)]
struct UsersCreate {
username: String,
}
async fn root(_: Request<Body>) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from("Hello, World!")))
}
async fn large_static_file(
_: Request<Body>,
body: extract::BytesMaxLength<{ 1024 * 500 }>,
) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::empty()))
}
let app = app()
// routes with functions
.at("/")
.get(root)
// routes with closures
.at("/users")
.get(
|_: Request<Body>, pagination: extract::Query<Pagination>| async {
let pagination = pagination.into_inner();
assert_eq!(pagination.page, 1);
assert_eq!(pagination.per_page, 30);
Ok::<_, Error>("users#index".to_string())
},
)
.post(
|_: Request<Body>,
payload: extract::Json<UsersCreate>,
_state: extract::Extension<Arc<State>>| async {
let payload = payload.into_inner();
assert_eq!(payload.username, "bob");
Ok::<_, Error>(response::Json(
serde_json::json!({ "username": payload.username }),
))
},
)
// routes with a service
.at("/service")
.get_service(service_fn(root))
// routes with layers applied
.at("/large-static-file")
.get(
large_static_file.layer(
ServiceBuilder::new()
.layer(TimeoutLayer::new(Duration::from_secs(30)))
.layer(CompressionLayer::new())
.into_inner(),
),
)
.into_service();
// state shared by all routes, could hold db connection etc
struct State {}
let state = Arc::new(State {});
// can add more middleware
let mut app = ServiceBuilder::new()
.layer(AddExtensionLayer::new(state))
.layer(TraceLayer::new_for_http())
.service(app);
let res = app
.ready()
.await
.unwrap()
.call(
Request::builder()
.method(Method::GET)
.uri("/")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(body_to_string(res).await, "Hello, World!");
let res = app
.ready()
.await
.unwrap()
.call(
Request::builder()
.method(Method::GET)
.uri("/users?page=1&per_page=30")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(body_to_string(res).await, "users#index");
let res = app
.ready()
.await
.unwrap()
.call(
Request::builder()
.method(Method::GET)
.uri("/users")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
assert_eq!(body_to_string(res).await, "");
let res = app
.ready()
.await
.unwrap()
.call(
Request::builder()
.method(Method::POST)
.uri("/users")
.body(Body::from(r#"{ "username": "bob" }"#))
.unwrap(),
)
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(body_to_string(res).await, r#"{"username":"bob"}"#);
}
async fn body_to_string<B>(res: Response<B>) -> String
where
B: http_body::Body,
B::Error: fmt::Debug,
{
let bytes = hyper::body::to_bytes(res.into_body()).await.unwrap();
String::from_utf8(bytes.to_vec()).unwrap()
}
}

192
src/tests.rs Normal file
View file

@ -0,0 +1,192 @@
use crate::{app, extract, response};
use http::{Request, Response, StatusCode};
use hyper::{Body, Server};
use serde::Deserialize;
use serde_json::json;
use std::net::{SocketAddr, TcpListener};
use tower::{make::Shared, BoxError, Service};
#[tokio::test]
async fn hello_world() {
let app = app()
.at("/")
.get(|_: Request<Body>| async { Ok("Hello, World!") })
.into_service();
let addr = run_in_background(app).await;
let res = reqwest::get(format!("http://{}", addr)).await.unwrap();
let body = res.text().await.unwrap();
assert_eq!(body, "Hello, World!");
}
#[tokio::test]
async fn consume_body() {
let app = app()
.at("/")
.get(|_: Request<Body>, body: String| async { Ok(body) })
.into_service();
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}", addr))
.body("foo")
.send()
.await
.unwrap();
let body = res.text().await.unwrap();
assert_eq!(body, "foo");
}
#[tokio::test]
async fn deserialize_body() {
#[derive(Debug, Deserialize)]
struct Input {
foo: String,
}
let app = app()
.at("/")
.post(|_: Request<Body>, input: extract::Json<Input>| async { Ok(input.into_inner().foo) })
.into_service();
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.post(format!("http://{}", addr))
.json(&json!({ "foo": "bar" }))
.send()
.await
.unwrap();
let body = res.text().await.unwrap();
assert_eq!(body, "bar");
}
#[tokio::test]
async fn consume_body_to_json_requires_json_content_type() {
#[derive(Debug, Deserialize)]
struct Input {
foo: String,
}
let app = app()
.at("/")
.post(|_: Request<Body>, input: extract::Json<Input>| async {
let input = input.into_inner();
Ok(input.foo)
})
.into_service();
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.post(format!("http://{}", addr))
.body(r#"{ "foo": "bar" }"#)
.send()
.await
.unwrap();
// TODO(david): is this the most appropriate response code?
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn body_with_length_limit() {
use std::iter::repeat;
#[derive(Debug, Deserialize)]
struct Input {
foo: String,
}
const LIMIT: u64 = 8;
let app = app()
.at("/")
.post(
|req: Request<Body>, _body: extract::BytesMaxLength<LIMIT>| async move {
dbg!(&req);
Ok(response::Empty)
},
)
.into_service();
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.post(format!("http://{}", addr))
.body(repeat(0_u8).take((LIMIT - 1) as usize).collect::<Vec<_>>())
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let res = client
.post(format!("http://{}", addr))
.body(repeat(0_u8).take(LIMIT as usize).collect::<Vec<_>>())
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let res = client
.post(format!("http://{}", addr))
.body(repeat(0_u8).take((LIMIT + 1) as usize).collect::<Vec<_>>())
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
let res = client
.post(format!("http://{}", addr))
.body(reqwest::Body::wrap_stream(futures_util::stream::iter(
vec![Ok::<_, std::io::Error>(bytes::Bytes::new())],
)))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::LENGTH_REQUIRED);
}
// TODO(david): can extractors change the request type?
// TODO(david): should FromRequest be an async-trait?
// TODO(david): routing
// TODO(david): lots of routes and boxing, shouldn't take forever to compile
/// Run a `tower::Service` in the background and get a URI for it.
pub async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
where
S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
ResBody: http_body::Body + Send + 'static,
ResBody::Data: Send,
ResBody::Error: Into<BoxError>,
S::Error: Into<BoxError>,
S::Future: Send,
{
let listener = TcpListener::bind("127.0.0.1:0").expect("Could not bind ephemeral socket");
let addr = listener.local_addr().unwrap();
println!("Listening on {}", addr);
let (tx, rx) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
let server = Server::from_tcp(listener).unwrap().serve(Shared::new(svc));
tx.send(()).unwrap();
server.await.expect("server error");
});
rx.await.unwrap();
addr
}