From 867dd8012ced4a6b261cde927ff6b2b44465bc3c Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 31 May 2021 14:04:05 +0200 Subject: [PATCH] Add some more tests --- src/extract.rs | 21 ++++++++--- src/tests.rs | 94 ++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 107 insertions(+), 8 deletions(-) diff --git a/src/extract.rs b/src/extract.rs index e516ed4e..3a9f7d53 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -10,6 +10,17 @@ pub trait FromRequest: Sized { async fn from_request(req: &mut Request) -> Result; } +fn take_body(req: &mut Request) -> Body { + struct BodyAlreadyTaken; + + if req.extensions_mut().insert(BodyAlreadyTaken).is_some() { + panic!("Cannot have two request body on extractors") + } else { + let body = std::mem::take(req.body_mut()); + body + } +} + #[async_trait] impl FromRequest for Option where @@ -57,7 +68,7 @@ where { async fn from_request(req: &mut Request) -> Result { if has_content_type(&req, "application/json") { - let body = std::mem::take(req.body_mut()); + let body = take_body(req); let bytes = hyper::body::to_bytes(body) .await @@ -116,7 +127,7 @@ where #[async_trait] impl FromRequest for Bytes { async fn from_request(req: &mut Request) -> Result { - let body = std::mem::take(req.body_mut()); + let body = take_body(req); let bytes = hyper::body::to_bytes(body) .await @@ -129,7 +140,7 @@ impl FromRequest for Bytes { #[async_trait] impl FromRequest for String { async fn from_request(req: &mut Request) -> Result { - let body = std::mem::take(req.body_mut()); + let body = take_body(req); let bytes = hyper::body::to_bytes(body) .await @@ -145,7 +156,7 @@ impl FromRequest for String { #[async_trait] impl FromRequest for Body { async fn from_request(req: &mut Request) -> Result { - let body = std::mem::take(req.body_mut()); + let body = take_body(req); Ok(body) } } @@ -163,7 +174,7 @@ impl BytesMaxLength { impl FromRequest for BytesMaxLength { async fn from_request(req: &mut Request) -> Result { let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned(); - let body = std::mem::take(req.body_mut()); + let body = take_body(req); let content_length = content_length.and_then(|value| value.to_str().ok()?.parse::().ok()); diff --git a/src/tests.rs b/src/tests.rs index a0924932..5a1012d7 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -157,10 +157,98 @@ async fn body_with_length_limit() { assert_eq!(res.status(), StatusCode::LENGTH_REQUIRED); } -// TODO(david): can extractors change the request type? -// TODO(david): should FromRequest be an async-trait? +#[tokio::test] +async fn routing() { + let app = app() + .at("/users") + .get(|_: Request| async { Ok("users#index") }) + .post(|_: Request| async { Ok("users#create") }) + .at("/users/:id") + .get(|_: Request| async { Ok("users#show") }) + .at("/users/:id/action") + .get(|_: Request| async { Ok("users#action") }) + .into_service(); -// TODO(david): routing + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client.get(format!("http://{}", addr)).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + let res = client + .get(format!("http://{}/users", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "users#index"); + + let res = client + .post(format!("http://{}/users", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "users#create"); + + let res = client + .get(format!("http://{}/users/1", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "users#show"); + + let res = client + .get(format!("http://{}/users/1/action", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "users#action"); +} + +#[tokio::test] +async fn extracting_url_params() { + let app = app() + .at("/users/:id") + .get( + |_: Request, params: extract::UrlParams<(i32,)>| async move { + let (id,) = params.into_inner(); + assert_eq!(id, 42); + + Ok(response::Empty) + }, + ) + .post( + |_: Request, params_map: extract::UrlParamsMap| async move { + assert_eq!(params_map.get("id").unwrap(), "1337"); + assert_eq!(params_map.get_typed::("id").unwrap(), 1337); + + Ok(response::Empty) + }, + ) + .into_service(); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/users/42", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .post(format!("http://{}/users/1337", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); +} // TODO(david): lots of routes and boxing, shouldn't take forever to compile