Fix Query and Form extractors giving bad request error when query string is empty (#117)

Co-Authored-By: David Pedersen <david.pdrsn@gmail.com>

Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
This commit is contained in:
Sunli 2021-08-04 23:13:09 +08:00 committed by GitHub
parent 5c12328892
commit fb0b3b78eb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 175 additions and 6 deletions

View file

@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add support for WebSocket protocol negotiation. ([#83](https://github.com/tokio-rs/axum/pull/83))
- Use `pin-project-lite` instead of `pin-project`. ([#95](https://github.com/tokio-rs/axum/pull/95))
- Re-export `http` crate and `hyper::Server`. ([#110](https://github.com/tokio-rs/axum/pull/110))
- Fix `Query` and `Form` extractors giving bad request error when query string is empty. ([#117](https://github.com/tokio-rs/axum/pull/117))
## Breaking changes

View file

@ -50,20 +50,20 @@ where
#[allow(warnings)]
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
if !has_content_type(&req, "application/x-www-form-urlencoded")? {
Err(InvalidFormContentType)?;
}
if req.method().ok_or(MethodAlreadyExtracted)? == Method::GET {
let query = req
.uri()
.ok_or(UriAlreadyExtracted)?
.query()
.ok_or(QueryStringMissing)?;
.unwrap_or_default();
let value = serde_urlencoded::from_str(query)
.map_err(FailedToDeserializeQueryString::new::<T, _>)?;
Ok(Form(value))
} else {
if !has_content_type(&req, "application/x-www-form-urlencoded")? {
Err(InvalidFormContentType)?;
}
let body = take_body(req)?;
let chunks = hyper::body::aggregate(body)
.await
@ -83,3 +83,121 @@ impl<T> Deref for Form<T> {
&self.0
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::extract::RequestParts;
use http::Request;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct Pagination {
size: Option<u64>,
page: Option<u64>,
}
async fn check_query<T: DeserializeOwned + PartialEq + Debug>(uri: impl AsRef<str>, value: T) {
let mut req = RequestParts::new(
Request::builder()
.uri(uri.as_ref())
.body(http_body::Empty::<bytes::Bytes>::new())
.unwrap(),
);
assert_eq!(Form::<T>::from_request(&mut req).await.unwrap().0, value);
}
async fn check_body<T: Serialize + DeserializeOwned + PartialEq + Debug>(value: T) {
let mut req = RequestParts::new(
Request::builder()
.uri("http://example.com/test")
.method(Method::POST)
.header(
http::header::CONTENT_TYPE,
"application/x-www-form-urlencoded",
)
.body(http_body::Full::<bytes::Bytes>::new(
serde_urlencoded::to_string(&value).unwrap().into(),
))
.unwrap(),
);
assert_eq!(Form::<T>::from_request(&mut req).await.unwrap().0, value);
}
#[tokio::test]
async fn test_form_query() {
check_query(
"http://example.com/test",
Pagination {
size: None,
page: None,
},
)
.await;
check_query(
"http://example.com/test?size=10",
Pagination {
size: Some(10),
page: None,
},
)
.await;
check_query(
"http://example.com/test?size=10&page=20",
Pagination {
size: Some(10),
page: Some(20),
},
)
.await;
}
#[tokio::test]
async fn test_form_body() {
check_body(Pagination {
size: None,
page: None,
})
.await;
check_body(Pagination {
size: Some(10),
page: None,
})
.await;
check_body(Pagination {
size: Some(10),
page: Some(20),
})
.await;
}
#[tokio::test]
async fn test_incorrect_content_type() {
let mut req = RequestParts::new(
Request::builder()
.uri("http://example.com/test")
.method(Method::POST)
.header(http::header::CONTENT_TYPE, "application/json")
.body(http_body::Full::<bytes::Bytes>::new(
serde_urlencoded::to_string(&Pagination {
size: Some(10),
page: None,
})
.unwrap()
.into(),
))
.unwrap(),
);
assert!(matches!(
Form::<Pagination>::from_request(&mut req)
.await
.unwrap_err(),
FormRejection::InvalidFormContentType(InvalidFormContentType)
));
}
}

View file

@ -51,7 +51,7 @@ where
.uri()
.ok_or(UriAlreadyExtracted)?
.query()
.ok_or(QueryStringMissing)?;
.unwrap_or_default();
let value = serde_urlencoded::from_str(query)
.map_err(FailedToDeserializeQueryString::new::<T, _>)?;
Ok(Query(value))
@ -65,3 +65,53 @@ impl<T> Deref for Query<T> {
&self.0
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::extract::RequestParts;
use http::Request;
use serde::Deserialize;
use std::fmt::Debug;
async fn check<T: DeserializeOwned + PartialEq + Debug>(uri: impl AsRef<str>, value: T) {
let mut req = RequestParts::new(Request::builder().uri(uri.as_ref()).body(()).unwrap());
assert_eq!(Query::<T>::from_request(&mut req).await.unwrap().0, value);
}
#[tokio::test]
async fn test_query() {
#[derive(Debug, PartialEq, Deserialize)]
struct Pagination {
size: Option<u64>,
page: Option<u64>,
}
check(
"http://example.com/test",
Pagination {
size: None,
page: None,
},
)
.await;
check(
"http://example.com/test?size=10",
Pagination {
size: Some(10),
page: None,
},
)
.await;
check(
"http://example.com/test?size=10&page=20",
Pagination {
size: Some(10),
page: Some(20),
},
)
.await;
}
}