Add Form extractor (#12)

Fixes https://github.com/davidpdrsn/tower-web/issues/2
This commit is contained in:
David Pedersen 2021-06-13 11:01:40 +02:00 committed by GitHub
parent 59944c231f
commit 27ebb3db7a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 178 additions and 8 deletions

54
examples/form.rs Normal file
View file

@ -0,0 +1,54 @@
use http::Request;
use serde::Deserialize;
use std::net::SocketAddr;
use tower_web::prelude::*;
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
// build our application with some routes
let app = route("/", get(show_form).post(accept_form))
.layer(tower_http::trace::TraceLayer::new_for_http());
// run it with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
app.serve(&addr).await.unwrap();
}
async fn show_form(_req: Request<Body>) -> response::Html<&'static str> {
response::Html(
r#"
<!doctype html>
<html>
<head></head>
<body>
<form action="/" method="post">
<label for="name">
Enter your name:
<input type="text" name="name">
</label>
<label>
Enter your email:
<input type="text" name="email">
</label>
<input type="submit" value="Subscribe!">
</form>
</body>
</html>
"#,
)
}
#[derive(Deserialize, Debug)]
struct Input {
name: String,
email: String,
}
async fn accept_form(extract::Form(input): extract::Form<Input>) {
dbg!(&input);
}

View file

@ -132,11 +132,12 @@
use crate::{body::Body, response::IntoResponse};
use async_trait::async_trait;
use bytes::Bytes;
use bytes::{Buf, Bytes};
use http::{header, HeaderMap, Method, Request, Response, Uri, Version};
use rejection::{
BodyAlreadyExtracted, FailedToBufferBody, InvalidJsonBody, InvalidUrlParam, InvalidUtf8,
LengthRequired, MissingExtension, MissingJsonContentType, MissingRouteParams, PayloadTooLarge,
BodyAlreadyExtracted, FailedToBufferBody, FailedToDeserializeQueryString,
InvalidFormContentType, InvalidJsonBody, InvalidUrlParam, InvalidUtf8, LengthRequired,
MissingExtension, MissingJsonContentType, MissingRouteParams, PayloadTooLarge,
QueryStringMissing, RequestAlreadyExtracted, UrlParamsAlreadyExtracted,
};
use serde::de::DeserializeOwned;
@ -192,10 +193,11 @@ where
///
/// // ...
/// }
///
/// let app = route("/list_things", get(list_things));
/// ```
///
/// If the query string cannot be parsed it will reject the request with a `404
/// If the query string cannot be parsed it will reject the request with a `400
/// Bad Request` response.
#[derive(Debug, Clone, Copy, Default)]
pub struct Query<T>(pub T);
@ -205,15 +207,89 @@ impl<T> FromRequest for Query<T>
where
T: DeserializeOwned,
{
type Rejection = QueryStringMissing;
type Rejection = Response<Body>;
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
let query = req.uri().query().ok_or(QueryStringMissing)?;
let value = serde_urlencoded::from_str(query).map_err(|_| QueryStringMissing)?;
let query = req
.uri()
.query()
.ok_or(QueryStringMissing)
.map_err(IntoResponse::into_response)?;
let value = serde_urlencoded::from_str(query)
.map_err(FailedToDeserializeQueryString::new::<T, _>)
.map_err(IntoResponse::into_response)?;
Ok(Query(value))
}
}
/// Extractor that deserializes `application/x-www-form-urlencoded` requests
/// into some type.
///
/// `T` is expected to implement [`serde::Deserialize`].
///
/// # Example
///
/// ```rust,no_run
/// use tower_web::prelude::*;
/// use serde::Deserialize;
///
/// #[derive(Deserialize)]
/// struct SignUp {
/// username: String,
/// password: String,
/// }
///
/// async fn accept_form(form: extract::Form<SignUp>) {
/// let sign_up: SignUp = form.0;
///
/// // ...
/// }
///
/// let app = route("/sign_up", post(accept_form));
/// ```
///
/// Note that `Content-Type: multipart/form-data` requests are not supported.
#[derive(Debug, Clone, Copy, Default)]
pub struct Form<T>(pub T);
#[async_trait]
impl<T> FromRequest for Form<T>
where
T: DeserializeOwned,
{
type Rejection = Response<Body>;
#[allow(warnings)]
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
if !has_content_type(&req, "application/x-www-form-urlencoded") {
return Err(InvalidFormContentType.into_response());
}
if req.method() == Method::GET {
let query = req
.uri()
.query()
.ok_or(QueryStringMissing)
.map_err(IntoResponse::into_response)?;
let value = serde_urlencoded::from_str(query)
.map_err(FailedToDeserializeQueryString::new::<T, _>)
.map_err(IntoResponse::into_response)?;
Ok(Form(value))
} else {
let body = take_body(req).map_err(IntoResponse::into_response)?;
let chunks = hyper::body::aggregate(body)
.await
.map_err(FailedToBufferBody::from_err)
.map_err(IntoResponse::into_response)?;
let value = serde_urlencoded::from_reader(chunks.reader())
.map_err(FailedToDeserializeQueryString::new::<T, _>)
.map_err(IntoResponse::into_response)?;
Ok(Form(value))
}
}
}
/// Extractor that deserializes request bodies into some type.
///
/// `T` is expected to implement [`serde::Deserialize`].
@ -239,7 +315,7 @@ where
/// let app = route("/users", post(create_user));
/// ```
///
/// If the query string cannot be parsed it will reject the request with a `404
/// If the query string cannot be parsed it will reject the request with a `400
/// Bad Request` response.
///
/// The request is required to have a `Content-Type: application/json` header.

View file

@ -1,5 +1,7 @@
//! Rejection response types.
use tower::BoxError;
use super::IntoResponse;
use crate::body::Body;
@ -147,6 +149,13 @@ define_rejection! {
pub struct RequestAlreadyExtracted;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Form requests must have `Content-Type: x-www-form-urlencoded`"]
/// Rejection type used if you try and extract the request more than once.
pub struct InvalidFormContentType;
}
/// Rejection type for [`UrlParams`](super::UrlParams) if the capture route
/// param didn't have the expected type.
#[derive(Debug)]
@ -172,3 +181,34 @@ impl IntoResponse for InvalidUrlParam {
res
}
}
/// Rejection type for extractors that deserialize query strings if the input
/// couldn't be deserialized into the target type.
#[derive(Debug)]
pub struct FailedToDeserializeQueryString {
error: BoxError,
type_name: &'static str,
}
impl FailedToDeserializeQueryString {
pub(super) fn new<T, E>(error: E) -> Self
where
E: Into<BoxError>,
{
FailedToDeserializeQueryString {
error: error.into(),
type_name: std::any::type_name::<T>(),
}
}
}
impl IntoResponse for FailedToDeserializeQueryString {
fn into_response(self) -> http::Response<Body> {
let mut res = http::Response::new(Body::from(format!(
"Failed to deserialize query string. Expected something of type `{}`. Error: {}",
self.type_name, self.error,
)));
*res.status_mut() = http::StatusCode::BAD_REQUEST;
res
}
}