mirror of
https://github.com/tokio-rs/axum.git
synced 2025-03-31 11:49:55 +02:00
Start writing more tests
This commit is contained in:
parent
6822766165
commit
593c901aab
5 changed files with 259 additions and 196 deletions
|
@ -19,10 +19,11 @@ thiserror = "1.0"
|
|||
tower = { version = "0.4", features = ["util"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] }
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
reqwest = { version = "0.11", features = ["json", "stream"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] }
|
||||
tower = { version = "0.4", features = ["util", "make", "timeout"] }
|
||||
tower-http = { version = "0.1", features = ["trace", "compression", "add-extension"] }
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = "0.2"
|
||||
|
|
|
@ -44,6 +44,9 @@ pub enum Error {
|
|||
|
||||
#[error("unknown URL param `{0}`")]
|
||||
UnknownUrlParam(String),
|
||||
|
||||
#[error("response body didn't contain valid UTF-8")]
|
||||
InvalidUtf8,
|
||||
}
|
||||
|
||||
impl From<Infallible> for Error {
|
||||
|
@ -69,7 +72,8 @@ where
|
|||
Error::DeserializeRequestBody(_)
|
||||
| Error::QueryStringMissing
|
||||
| Error::DeserializeQueryString(_)
|
||||
| Error::InvalidUrlParam { .. } => make_response(StatusCode::BAD_REQUEST),
|
||||
| Error::InvalidUrlParam { .. }
|
||||
| Error::InvalidUtf8 => make_response(StatusCode::BAD_REQUEST),
|
||||
|
||||
Error::Status(status) => make_response(status),
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{body::Body, Error};
|
||||
use bytes::Bytes;
|
||||
use futures_util::{future, ready};
|
||||
use http::Request;
|
||||
use http::{header, Request, StatusCode};
|
||||
use pin_project::pin_project;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::{
|
||||
|
@ -86,20 +86,39 @@ where
|
|||
type Future = future::BoxFuture<'static, Result<Self, Error>>;
|
||||
|
||||
fn from_request(req: &mut Request<Body>) -> Self::Future {
|
||||
// TODO(david): require the body to have `content-type: application/json`
|
||||
if has_content_type(&req, "application/json") {
|
||||
let body = std::mem::take(req.body_mut());
|
||||
|
||||
let body = std::mem::take(req.body_mut());
|
||||
|
||||
Box::pin(async move {
|
||||
let bytes = hyper::body::to_bytes(body)
|
||||
.await
|
||||
.map_err(Error::ConsumeRequestBody)?;
|
||||
let value = serde_json::from_slice(&bytes).map_err(Error::DeserializeRequestBody)?;
|
||||
Ok(Json(value))
|
||||
})
|
||||
Box::pin(async move {
|
||||
let bytes = hyper::body::to_bytes(body)
|
||||
.await
|
||||
.map_err(Error::ConsumeRequestBody)?;
|
||||
let value =
|
||||
serde_json::from_slice(&bytes).map_err(Error::DeserializeRequestBody)?;
|
||||
Ok(Json(value))
|
||||
})
|
||||
} else {
|
||||
Box::pin(async { Err(Error::Status(StatusCode::BAD_REQUEST)) })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn has_content_type<B>(req: &Request<B>, expected_content_type: &str) -> bool {
|
||||
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
|
||||
content_type
|
||||
} else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let content_type = if let Ok(content_type) = content_type.to_str() {
|
||||
content_type
|
||||
} else {
|
||||
return false;
|
||||
};
|
||||
|
||||
content_type.starts_with(expected_content_type)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Extension<T>(T);
|
||||
|
||||
|
@ -146,6 +165,32 @@ impl FromRequest for Bytes {
|
|||
}
|
||||
}
|
||||
|
||||
impl FromRequest for String {
|
||||
type Future = future::BoxFuture<'static, Result<Self, Error>>;
|
||||
|
||||
fn from_request(req: &mut Request<Body>) -> Self::Future {
|
||||
let body = std::mem::take(req.body_mut());
|
||||
|
||||
Box::pin(async move {
|
||||
let bytes = hyper::body::to_bytes(body)
|
||||
.await
|
||||
.map_err(Error::ConsumeRequestBody)?
|
||||
.to_vec();
|
||||
let string = String::from_utf8(bytes).map_err(|_| Error::InvalidUtf8)?;
|
||||
Ok(string)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRequest for Body {
|
||||
type Future = future::Ready<Result<Self, Error>>;
|
||||
|
||||
fn from_request(req: &mut Request<Body>) -> Self::Future {
|
||||
let body = std::mem::take(req.body_mut());
|
||||
future::ok(body)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BytesMaxLength<const N: u64>(Bytes);
|
||||
|
||||
|
|
185
src/lib.rs
185
src/lib.rs
|
@ -1,15 +1,3 @@
|
|||
/*
|
||||
|
||||
Improvements to make:
|
||||
|
||||
Support extracting headers, perhaps via `headers::Header`?
|
||||
|
||||
Improve compile times with lots of routes, can we box and combine routers?
|
||||
|
||||
Tests
|
||||
|
||||
*/
|
||||
|
||||
use self::{
|
||||
body::Body,
|
||||
routing::{EmptyRouter, RouteAt},
|
||||
|
@ -33,6 +21,9 @@ pub mod routing;
|
|||
|
||||
mod error;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
pub use self::error::Error;
|
||||
|
||||
pub fn app() -> App<EmptyRouter> {
|
||||
|
@ -139,173 +130,3 @@ where
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(warnings)]
|
||||
use super::*;
|
||||
use crate::handler::Handler;
|
||||
use http::{Method, Request, StatusCode};
|
||||
use hyper::Server;
|
||||
use serde::Deserialize;
|
||||
use std::time::Duration;
|
||||
use std::{fmt, net::SocketAddr, sync::Arc};
|
||||
use tower::{
|
||||
layer::util::Identity, make::Shared, service_fn, timeout::TimeoutLayer, ServiceBuilder,
|
||||
ServiceExt,
|
||||
};
|
||||
use tower_http::{
|
||||
add_extension::AddExtensionLayer,
|
||||
compression::CompressionLayer,
|
||||
trace::{Trace, TraceLayer},
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn basic() {
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Pagination {
|
||||
page: usize,
|
||||
per_page: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UsersCreate {
|
||||
username: String,
|
||||
}
|
||||
|
||||
async fn root(_: Request<Body>) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::from("Hello, World!")))
|
||||
}
|
||||
|
||||
async fn large_static_file(
|
||||
_: Request<Body>,
|
||||
body: extract::BytesMaxLength<{ 1024 * 500 }>,
|
||||
) -> Result<Response<Body>, Error> {
|
||||
Ok(Response::new(Body::empty()))
|
||||
}
|
||||
|
||||
let app = app()
|
||||
// routes with functions
|
||||
.at("/")
|
||||
.get(root)
|
||||
// routes with closures
|
||||
.at("/users")
|
||||
.get(
|
||||
|_: Request<Body>, pagination: extract::Query<Pagination>| async {
|
||||
let pagination = pagination.into_inner();
|
||||
assert_eq!(pagination.page, 1);
|
||||
assert_eq!(pagination.per_page, 30);
|
||||
Ok::<_, Error>("users#index".to_string())
|
||||
},
|
||||
)
|
||||
.post(
|
||||
|_: Request<Body>,
|
||||
payload: extract::Json<UsersCreate>,
|
||||
_state: extract::Extension<Arc<State>>| async {
|
||||
let payload = payload.into_inner();
|
||||
assert_eq!(payload.username, "bob");
|
||||
Ok::<_, Error>(response::Json(
|
||||
serde_json::json!({ "username": payload.username }),
|
||||
))
|
||||
},
|
||||
)
|
||||
// routes with a service
|
||||
.at("/service")
|
||||
.get_service(service_fn(root))
|
||||
// routes with layers applied
|
||||
.at("/large-static-file")
|
||||
.get(
|
||||
large_static_file.layer(
|
||||
ServiceBuilder::new()
|
||||
.layer(TimeoutLayer::new(Duration::from_secs(30)))
|
||||
.layer(CompressionLayer::new())
|
||||
.into_inner(),
|
||||
),
|
||||
)
|
||||
.into_service();
|
||||
|
||||
// state shared by all routes, could hold db connection etc
|
||||
struct State {}
|
||||
|
||||
let state = Arc::new(State {});
|
||||
|
||||
// can add more middleware
|
||||
let mut app = ServiceBuilder::new()
|
||||
.layer(AddExtensionLayer::new(state))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.service(app);
|
||||
|
||||
let res = app
|
||||
.ready()
|
||||
.await
|
||||
.unwrap()
|
||||
.call(
|
||||
Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri("/")
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(body_to_string(res).await, "Hello, World!");
|
||||
|
||||
let res = app
|
||||
.ready()
|
||||
.await
|
||||
.unwrap()
|
||||
.call(
|
||||
Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri("/users?page=1&per_page=30")
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(body_to_string(res).await, "users#index");
|
||||
|
||||
let res = app
|
||||
.ready()
|
||||
.await
|
||||
.unwrap()
|
||||
.call(
|
||||
Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri("/users")
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
assert_eq!(body_to_string(res).await, "");
|
||||
|
||||
let res = app
|
||||
.ready()
|
||||
.await
|
||||
.unwrap()
|
||||
.call(
|
||||
Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri("/users")
|
||||
.body(Body::from(r#"{ "username": "bob" }"#))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(body_to_string(res).await, r#"{"username":"bob"}"#);
|
||||
}
|
||||
|
||||
async fn body_to_string<B>(res: Response<B>) -> String
|
||||
where
|
||||
B: http_body::Body,
|
||||
B::Error: fmt::Debug,
|
||||
{
|
||||
let bytes = hyper::body::to_bytes(res.into_body()).await.unwrap();
|
||||
String::from_utf8(bytes.to_vec()).unwrap()
|
||||
}
|
||||
}
|
||||
|
|
192
src/tests.rs
Normal file
192
src/tests.rs
Normal file
|
@ -0,0 +1,192 @@
|
|||
use crate::{app, extract, response};
|
||||
use http::{Request, Response, StatusCode};
|
||||
use hyper::{Body, Server};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use std::net::{SocketAddr, TcpListener};
|
||||
use tower::{make::Shared, BoxError, Service};
|
||||
|
||||
#[tokio::test]
|
||||
async fn hello_world() {
|
||||
let app = app()
|
||||
.at("/")
|
||||
.get(|_: Request<Body>| async { Ok("Hello, World!") })
|
||||
.into_service();
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let res = reqwest::get(format!("http://{}", addr)).await.unwrap();
|
||||
let body = res.text().await.unwrap();
|
||||
|
||||
assert_eq!(body, "Hello, World!");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn consume_body() {
|
||||
let app = app()
|
||||
.at("/")
|
||||
.get(|_: Request<Body>, body: String| async { Ok(body) })
|
||||
.into_service();
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let res = client
|
||||
.get(format!("http://{}", addr))
|
||||
.body("foo")
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
let body = res.text().await.unwrap();
|
||||
|
||||
assert_eq!(body, "foo");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn deserialize_body() {
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Input {
|
||||
foo: String,
|
||||
}
|
||||
|
||||
let app = app()
|
||||
.at("/")
|
||||
.post(|_: Request<Body>, input: extract::Json<Input>| async { Ok(input.into_inner().foo) })
|
||||
.into_service();
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let res = client
|
||||
.post(format!("http://{}", addr))
|
||||
.json(&json!({ "foo": "bar" }))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
let body = res.text().await.unwrap();
|
||||
|
||||
assert_eq!(body, "bar");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn consume_body_to_json_requires_json_content_type() {
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Input {
|
||||
foo: String,
|
||||
}
|
||||
|
||||
let app = app()
|
||||
.at("/")
|
||||
.post(|_: Request<Body>, input: extract::Json<Input>| async {
|
||||
let input = input.into_inner();
|
||||
Ok(input.foo)
|
||||
})
|
||||
.into_service();
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let res = client
|
||||
.post(format!("http://{}", addr))
|
||||
.body(r#"{ "foo": "bar" }"#)
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// TODO(david): is this the most appropriate response code?
|
||||
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn body_with_length_limit() {
|
||||
use std::iter::repeat;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Input {
|
||||
foo: String,
|
||||
}
|
||||
|
||||
const LIMIT: u64 = 8;
|
||||
|
||||
let app = app()
|
||||
.at("/")
|
||||
.post(
|
||||
|req: Request<Body>, _body: extract::BytesMaxLength<LIMIT>| async move {
|
||||
dbg!(&req);
|
||||
Ok(response::Empty)
|
||||
},
|
||||
)
|
||||
.into_service();
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client
|
||||
.post(format!("http://{}", addr))
|
||||
.body(repeat(0_u8).take((LIMIT - 1) as usize).collect::<Vec<_>>())
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let res = client
|
||||
.post(format!("http://{}", addr))
|
||||
.body(repeat(0_u8).take(LIMIT as usize).collect::<Vec<_>>())
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let res = client
|
||||
.post(format!("http://{}", addr))
|
||||
.body(repeat(0_u8).take((LIMIT + 1) as usize).collect::<Vec<_>>())
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
|
||||
|
||||
let res = client
|
||||
.post(format!("http://{}", addr))
|
||||
.body(reqwest::Body::wrap_stream(futures_util::stream::iter(
|
||||
vec![Ok::<_, std::io::Error>(bytes::Bytes::new())],
|
||||
)))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::LENGTH_REQUIRED);
|
||||
}
|
||||
|
||||
// TODO(david): can extractors change the request type?
|
||||
// TODO(david): should FromRequest be an async-trait?
|
||||
|
||||
// TODO(david): routing
|
||||
|
||||
// TODO(david): lots of routes and boxing, shouldn't take forever to compile
|
||||
|
||||
/// Run a `tower::Service` in the background and get a URI for it.
|
||||
pub async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
|
||||
where
|
||||
S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
|
||||
ResBody: http_body::Body + Send + 'static,
|
||||
ResBody::Data: Send,
|
||||
ResBody::Error: Into<BoxError>,
|
||||
S::Error: Into<BoxError>,
|
||||
S::Future: Send,
|
||||
{
|
||||
let listener = TcpListener::bind("127.0.0.1:0").expect("Could not bind ephemeral socket");
|
||||
let addr = listener.local_addr().unwrap();
|
||||
println!("Listening on {}", addr);
|
||||
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let server = Server::from_tcp(listener).unwrap().serve(Shared::new(svc));
|
||||
tx.send(()).unwrap();
|
||||
server.await.expect("server error");
|
||||
});
|
||||
|
||||
rx.await.unwrap();
|
||||
|
||||
addr
|
||||
}
|
Loading…
Add table
Reference in a new issue