Clarify docs around body extractors

This commit is contained in:
David Pedersen 2021-07-23 00:26:08 +02:00
parent f25f7f90ff
commit d927c819d3
3 changed files with 105 additions and 48 deletions

View file

@ -172,6 +172,77 @@
//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! # Request body extractors
//!
//! Most of the time your request body type will be [`body::Body`] (a re-export
//! of [`hyper::Body`]), which is directly supported by all extractors.
//!
//! However if you're applying a tower middleware that changes the response you
//! might have to apply a different body type to some extractors:
//!
//! ```rust
//! use std::{
//! task::{Context, Poll},
//! pin::Pin,
//! };
//! use tower_http::map_request_body::MapRequestBodyLayer;
//! use axum::prelude::*;
//!
//! struct MyBody<B>(B);
//!
//! impl<B> http_body::Body for MyBody<B>
//! where
//! B: http_body::Body + Unpin,
//! {
//! type Data = B::Data;
//! type Error = B::Error;
//!
//! fn poll_data(
//! mut self: Pin<&mut Self>,
//! cx: &mut Context<'_>,
//! ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
//! Pin::new(&mut self.0).poll_data(cx)
//! }
//!
//! fn poll_trailers(
//! mut self: Pin<&mut Self>,
//! cx: &mut Context<'_>,
//! ) -> Poll<Result<Option<headers::HeaderMap>, Self::Error>> {
//! Pin::new(&mut self.0).poll_trailers(cx)
//! }
//! }
//!
//! let app =
//! // `String` works directly with any body type
//! route(
//! "/string",
//! get(|_: String| async {})
//! )
//! .route(
//! "/body",
//! // `extract::Body` defaults to `axum::body::Body`
//! // but can be customized
//! get(|_: extract::Body<MyBody<Body>>| async {})
//! )
//! .route(
//! "/body-stream",
//! // same for `extract::BodyStream`
//! get(|_: extract::BodyStream<MyBody<Body>>| async {}),
//! )
//! .route(
//! // and `Request<_>`
//! "/request",
//! get(|_: Request<MyBody<Body>>| async {})
//! )
//! // middleware that changes the request body type
//! .layer(MapRequestBodyLayer::new(MyBody));
//! # async {
//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! [`body::Body`]: crate::body::Body
use crate::{response::IntoResponse, util::ByteStr};
use async_trait::async_trait;
@ -788,6 +859,39 @@ where
}
}
/// Extractor that extracts the request body.
///
/// # Example
///
/// ```rust,no_run
/// use axum::prelude::*;
/// use futures::StreamExt;
///
/// async fn handler(extract::Body(body): extract::Body) {
/// // ...
/// }
///
/// let app = route("/users", get(handler));
/// # async {
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
#[derive(Debug, Default, Clone)]
pub struct Body<B = crate::body::Body>(pub B);
#[async_trait]
impl<B> FromRequest<B> for Body<B>
where
B: Send,
{
type Rejection = BodyAlreadyExtracted;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?;
Ok(Self(body))
}
}
#[async_trait]
impl<B> FromRequest<B> for Request<B>
where

View file

@ -487,7 +487,7 @@
//! "/",
//! service::any(service_fn(|_: Request<Body>| async {
//! let res = Response::new(Body::from("Hi from `GET /`"));
//! Ok::<_, Infallible>(res)
//! Ok(res)
//! }))
//! ).route(
//! // GET `/static/Cargo.toml` goes to a service from tower-http

View file

@ -612,53 +612,6 @@ async fn typed_header() {
assert_eq!(body, "invalid HTTP header (user-agent)");
}
#[tokio::test]
async fn different_request_body_types() {
use http_body::{Empty, Full};
use std::convert::Infallible;
use tower_http::map_request_body::MapRequestBodyLayer;
async fn handler(body: String) -> String {
body
}
async fn svc_handler<B>(req: Request<B>) -> Result<Response<Body>, Infallible>
where
B: http_body::Body,
B::Error: std::fmt::Debug,
{
let body = hyper::body::to_bytes(req.into_body()).await.unwrap();
Ok(Response::new(Body::from(body)))
}
let app = route("/", service::get(service_fn(svc_handler)))
.route(
"/foo",
get(handler.layer(MapRequestBodyLayer::new(|_| Full::<Bytes>::from("foo")))),
)
.layer(MapRequestBodyLayer::new(|_| Empty::<Bytes>::new()));
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/", addr))
.send()
.await
.unwrap();
let body = res.text().await.unwrap();
assert_eq!(body, "");
let res = client
.get(format!("http://{}/foo", addr))
.send()
.await
.unwrap();
let body = res.text().await.unwrap();
assert_eq!(body, "foo");
}
#[tokio::test]
async fn service_in_bottom() {
async fn handler(_req: Request<hyper::Body>) -> Result<Response<hyper::Body>, hyper::Error> {