Clean up RequestParts API (#167)

In http-body 0.4.3 `BoxBody` implements `Default`. This allows us to
clean up the API of `RequestParts` quite a bit.
This commit is contained in:
David Pedersen 2021-08-08 19:48:30 +02:00 committed by GitHub
parent bc27b09f5c
commit 6b218c7150
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 112 additions and 38 deletions

View file

@ -30,6 +30,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Removed `extract::UrlParams` and `extract::UrlParamsMap`. Use `extract::Path` instead
- `EmptyRouter` now requires the response body to implement `Send + Sync + 'static'` ([#108](https://github.com/tokio-rs/axum/pull/108))
- `ServiceExt` has been removed and its methods have been moved to `RoutingDsl` ([#160](https://github.com/tokio-rs/axum/pull/160))
- `extractor_middleware` now requires `RequestBody: Default` ([#167](https://github.com/tokio-rs/axum/pull/167))
- Convert `RequestAlreadyExtracted` to an enum with each possible error variant ([#167](https://github.com/tokio-rs/axum/pull/167))
- These future types have been moved
- `extract::extractor_middleware::ExtractorMiddlewareResponseFuture` moved
to `extract::extractor_middleware::future::ResponseFuture` ([#133](https://github.com/tokio-rs/axum/pull/133))

View file

@ -22,7 +22,7 @@ bitflags = "1.0"
bytes = "1.0"
futures-util = "0.3"
http = "0.2"
http-body = "0.4.2"
http-body = "0.4.3"
hyper = { version = "0.14", features = ["server", "tcp", "http1", "stream"] }
pin-project-lite = "0.2.7"
regex = "1.5"

View file

@ -13,6 +13,16 @@ impl Error {
inner: error.into(),
}
}
pub(crate) fn downcast<T>(self) -> Result<T, Self>
where
T: StdError + 'static,
{
match self.inner.downcast::<T>() {
Ok(t) => Ok(*t),
Err(err) => Err(*err.downcast().unwrap()),
}
}
}
impl fmt::Display for Error {

View file

@ -152,7 +152,7 @@ where
impl<S, E, ReqBody, ResBody> Service<Request<ReqBody>> for ExtractorMiddleware<S, E>
where
E: FromRequest<ReqBody> + 'static,
ReqBody: Send + 'static,
ReqBody: Default + Send + 'static,
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
ResBody::Error: Into<BoxError>,
@ -212,6 +212,7 @@ impl<ReqBody, S, E, ResBody> Future for ResponseFuture<ReqBody, S, E>
where
E: FromRequest<ReqBody>,
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
ReqBody: Default,
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
ResBody::Error: Into<BoxError>,
{
@ -223,12 +224,13 @@ where
let new_state = match this.state.as_mut().project() {
StateProj::Extracting { future } => {
let (mut req, extracted) = ready!(future.as_mut().poll(cx));
let (req, extracted) = ready!(future.as_mut().poll(cx));
match extracted {
Ok(_) => {
let mut svc = this.svc.take().expect("future polled after completion");
let future = svc.call(req.into_request());
let req = req.try_into_request().unwrap_or_default();
let future = svc.call(req);
State::Call { future }
}
Err(err) => {

View file

@ -244,7 +244,7 @@
//!
//! [`body::Body`]: crate::body::Body
use crate::response::IntoResponse;
use crate::{response::IntoResponse, Error};
use async_trait::async_trait;
use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version};
use rejection::*;
@ -397,32 +397,47 @@ impl<B> RequestParts<B> {
}
}
#[allow(clippy::wrong_self_convention)]
pub(crate) fn into_request(&mut self) -> Request<B> {
// this method uses `Error` since we might make this method public one day and then
// `Error` is more flexible.
pub(crate) fn try_into_request(self) -> Result<Request<B>, Error> {
let Self {
method,
uri,
version,
headers,
extensions,
body,
mut headers,
mut extensions,
mut body,
} = self;
let mut req = Request::new(body.take().expect("body already extracted"));
let mut req = if let Some(body) = body.take() {
Request::new(body)
} else {
return Err(Error::new(RequestAlreadyExtracted::BodyAlreadyExtracted(
BodyAlreadyExtracted,
)));
};
*req.method_mut() = method.clone();
*req.uri_mut() = uri.clone();
*req.version_mut() = *version;
*req.method_mut() = method;
*req.uri_mut() = uri;
*req.version_mut() = version;
if let Some(headers) = headers.take() {
*req.headers_mut() = headers;
} else {
return Err(Error::new(
RequestAlreadyExtracted::HeadersAlreadyExtracted(HeadersAlreadyExtracted),
));
}
if let Some(extensions) = extensions.take() {
*req.extensions_mut() = extensions;
} else {
return Err(Error::new(
RequestAlreadyExtracted::ExtensionsAlreadyExtracted(ExtensionsAlreadyExtracted),
));
}
req
Ok(req)
}
/// Gets a reference the request method.

View file

@ -13,14 +13,15 @@ use tower::BoxError;
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Extensions taken by other extractor"]
/// Rejection used if the method has been taken by another extractor.
/// Rejection used if the request extension has been taken by another
/// extractor.
pub struct ExtensionsAlreadyExtracted;
}
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Headers taken by other extractor"]
/// Rejection used if the URI has been taken by another extractor.
/// Rejection used if the headers has been taken by another extractor.
pub struct HeadersAlreadyExtracted;
}
@ -94,13 +95,6 @@ define_rejection! {
pub struct BodyAlreadyExtracted;
}
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Cannot have two `Request<_>` extractors for a single handler"]
/// Rejection type used if you try and extract the request more than once.
pub struct RequestAlreadyExtracted;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Form requests must have `Content-Type: x-www-form-urlencoded`"]
@ -272,6 +266,19 @@ composite_rejection! {
}
}
composite_rejection! {
/// Rejection used for [`Request<_>`].
///
/// Contains one variant for each way the [`Request<_>`] extractor can fail.
///
/// [`Request<_>`]: http::Request
pub enum RequestAlreadyExtracted {
BodyAlreadyExtracted,
HeadersAlreadyExtracted,
ExtensionsAlreadyExtracted,
}
}
/// Rejection used for [`ContentLengthLimit`](super::ContentLengthLimit).
///
/// Contains one variant for each way the

View file

@ -18,21 +18,29 @@ where
type Rejection = RequestAlreadyExtracted;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let RequestParts {
method: _,
uri: _,
version: _,
headers,
extensions,
body,
} = req;
let req = std::mem::replace(
req,
RequestParts {
method: req.method.clone(),
version: req.version,
uri: req.uri.clone(),
headers: None,
extensions: None,
body: None,
},
);
let all_parts = extensions.as_ref().zip(body.as_ref()).zip(headers.as_ref());
let err = match req.try_into_request() {
Ok(req) => return Ok(req),
Err(err) => err,
};
if all_parts.is_some() {
Ok(req.into_request())
} else {
Err(RequestAlreadyExtracted)
match err.downcast::<RequestAlreadyExtracted>() {
Ok(err) => return Err(err),
Err(err) => unreachable!(
"Unexpected error type from `try_into_request`: `{:?}`. This is a bug in axum, please file an issue",
err,
),
}
}
}
@ -251,3 +259,33 @@ where
Ok(string)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{body::Body, prelude::*, tests::*};
use http::StatusCode;
#[tokio::test]
async fn multiple_request_extractors() {
async fn handler(_: Request<Body>, _: Request<Body>) {}
let app = route("/", post(handler));
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.post(format!("http://{}", addr))
.body("hi there")
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(
res.text().await.unwrap(),
"Cannot have two request body extractors for a single handler"
);
}
}

View file

@ -605,7 +605,7 @@ async fn wrong_method_service() {
}
/// Run a `tower::Service` in the background and get a URI for it.
async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
pub(crate) async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
where
S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
ResBody: http_body::Body + Send + 'static,