From 9a6bc4e962aada341bab5afd2ca4fdd3823d90d5 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 3 Aug 2021 21:55:48 +0200 Subject: [PATCH] Break up `extract.rs` (#103) This breaks up `extract.rs` into several smaller submodules. The public API remains the same. This is done in prep for adding more tests to extractors which would get messy if they were all in the same file. --- src/extract/content_length_limit.rs | 69 +++ src/extract/extension.rs | 71 +++ src/extract/form.rs | 85 +++ src/extract/json.rs | 76 +++ src/extract/mod.rs | 898 ++-------------------------- src/extract/query.rs | 67 +++ src/extract/raw_query.rs | 39 ++ src/extract/request_parts.rs | 215 +++++++ src/extract/tuple.rs | 45 ++ src/extract/typed_header.rs | 61 ++ src/extract/url_params.rs | 97 +++ src/extract/url_params_map.rs | 66 ++ 12 files changed, 926 insertions(+), 863 deletions(-) create mode 100644 src/extract/content_length_limit.rs create mode 100644 src/extract/extension.rs create mode 100644 src/extract/form.rs create mode 100644 src/extract/json.rs create mode 100644 src/extract/query.rs create mode 100644 src/extract/raw_query.rs create mode 100644 src/extract/request_parts.rs create mode 100644 src/extract/tuple.rs create mode 100644 src/extract/typed_header.rs create mode 100644 src/extract/url_params.rs create mode 100644 src/extract/url_params_map.rs diff --git a/src/extract/content_length_limit.rs b/src/extract/content_length_limit.rs new file mode 100644 index 00000000..070dd889 --- /dev/null +++ b/src/extract/content_length_limit.rs @@ -0,0 +1,69 @@ +use super::{rejection::*, FromRequest, RequestParts}; +use async_trait::async_trait; +use std::ops::Deref; + +/// Extractor that will reject requests with a body larger than some size. +/// +/// # Example +/// +/// ```rust,no_run +/// use axum::prelude::*; +/// +/// async fn handler(body: extract::ContentLengthLimit) { +/// // ... +/// } +/// +/// let app = route("/", post(handler)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// This requires the request to have a `Content-Length` header. +#[derive(Debug, Clone)] +pub struct ContentLengthLimit(pub T); + +#[async_trait] +impl FromRequest for ContentLengthLimit +where + T: FromRequest, + B: Send, +{ + type Rejection = ContentLengthLimitRejection; + + async fn from_request(req: &mut RequestParts) -> Result { + let content_length = req + .headers() + .ok_or(ContentLengthLimitRejection::HeadersAlreadyExtracted( + HeadersAlreadyExtracted, + ))? + .get(http::header::CONTENT_LENGTH); + + let content_length = + content_length.and_then(|value| value.to_str().ok()?.parse::().ok()); + + if let Some(length) = content_length { + if length > N { + return Err(ContentLengthLimitRejection::PayloadTooLarge( + PayloadTooLarge, + )); + } + } else { + return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired)); + }; + + let value = T::from_request(req) + .await + .map_err(ContentLengthLimitRejection::Inner)?; + + Ok(Self(value)) + } +} + +impl Deref for ContentLengthLimit { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/src/extract/extension.rs b/src/extract/extension.rs new file mode 100644 index 00000000..47cb3d70 --- /dev/null +++ b/src/extract/extension.rs @@ -0,0 +1,71 @@ +use super::{rejection::*, FromRequest, RequestParts}; +use async_trait::async_trait; +use std::ops::Deref; + +/// Extractor that gets a value from request extensions. +/// +/// This is commonly used to share state across handlers. +/// +/// # Example +/// +/// ```rust,no_run +/// use axum::{AddExtensionLayer, prelude::*}; +/// use std::sync::Arc; +/// +/// // Some shared state used throughout our application +/// struct State { +/// // ... +/// } +/// +/// async fn handler(state: extract::Extension>) { +/// // ... +/// } +/// +/// let state = Arc::new(State { /* ... */ }); +/// +/// let app = route("/", get(handler)) +/// // Add middleware that inserts the state into all incoming request's +/// // extensions. +/// .layer(AddExtensionLayer::new(state)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// If the extension is missing it will reject the request with a `500 Interal +/// Server Error` response. +#[derive(Debug, Clone, Copy)] +pub struct Extension(pub T); + +#[async_trait] +impl FromRequest for Extension +where + T: Clone + Send + Sync + 'static, + B: Send, +{ + type Rejection = ExtensionRejection; + + async fn from_request(req: &mut RequestParts) -> Result { + let value = req + .extensions() + .ok_or(ExtensionsAlreadyExtracted)? + .get::() + .ok_or_else(|| { + MissingExtension::from_err(format!( + "Extension of type `{}` was not found. Perhaps you forgot to add it?", + std::any::type_name::() + )) + }) + .map(|x| x.clone())?; + + Ok(Extension(value)) + } +} + +impl Deref for Extension { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/src/extract/form.rs b/src/extract/form.rs new file mode 100644 index 00000000..deaaea51 --- /dev/null +++ b/src/extract/form.rs @@ -0,0 +1,85 @@ +use super::{has_content_type, rejection::*, take_body, FromRequest, RequestParts}; +use async_trait::async_trait; +use bytes::Buf; +use http::Method; +use serde::de::DeserializeOwned; +use std::ops::Deref; + +/// Extractor that deserializes `application/x-www-form-urlencoded` requests +/// into some type. +/// +/// `T` is expected to implement [`serde::Deserialize`]. +/// +/// # Example +/// +/// ```rust,no_run +/// use axum::prelude::*; +/// use serde::Deserialize; +/// +/// #[derive(Deserialize)] +/// struct SignUp { +/// username: String, +/// password: String, +/// } +/// +/// async fn accept_form(form: extract::Form) { +/// let sign_up: SignUp = form.0; +/// +/// // ... +/// } +/// +/// let app = route("/sign_up", post(accept_form)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// Note that `Content-Type: multipart/form-data` requests are not supported. +#[derive(Debug, Clone, Copy, Default)] +pub struct Form(pub T); + +#[async_trait] +impl FromRequest for Form +where + T: DeserializeOwned, + B: http_body::Body + Send, + B::Data: Send, + B::Error: Into, +{ + type Rejection = FormRejection; + + #[allow(warnings)] + async fn from_request(req: &mut RequestParts) -> Result { + if !has_content_type(&req, "application/x-www-form-urlencoded")? { + Err(InvalidFormContentType)?; + } + + if req.method().ok_or(MethodAlreadyExtracted)? == Method::GET { + let query = req + .uri() + .ok_or(UriAlreadyExtracted)? + .query() + .ok_or(QueryStringMissing)?; + let value = serde_urlencoded::from_str(query) + .map_err(FailedToDeserializeQueryString::new::)?; + Ok(Form(value)) + } else { + let body = take_body(req)?; + let chunks = hyper::body::aggregate(body) + .await + .map_err(FailedToBufferBody::from_err)?; + let value = serde_urlencoded::from_reader(chunks.reader()) + .map_err(FailedToDeserializeQueryString::new::)?; + + Ok(Form(value)) + } + } +} + +impl Deref for Form { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/src/extract/json.rs b/src/extract/json.rs new file mode 100644 index 00000000..e39b62e0 --- /dev/null +++ b/src/extract/json.rs @@ -0,0 +1,76 @@ +use super::{has_content_type, rejection::*, take_body, FromRequest, RequestParts}; +use async_trait::async_trait; +use serde::de::DeserializeOwned; +use std::ops::Deref; + +/// Extractor that deserializes request bodies into some type. +/// +/// `T` is expected to implement [`serde::Deserialize`]. +/// +/// # Example +/// +/// ```rust,no_run +/// use axum::prelude::*; +/// use serde::Deserialize; +/// +/// #[derive(Deserialize)] +/// struct CreateUser { +/// email: String, +/// password: String, +/// } +/// +/// async fn create_user(payload: extract::Json) { +/// let payload: CreateUser = payload.0; +/// +/// // ... +/// } +/// +/// let app = route("/users", post(create_user)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// If the query string cannot be parsed it will reject the request with a `400 +/// Bad Request` response. +/// +/// The request is required to have a `Content-Type: application/json` header. +#[derive(Debug, Clone, Copy, Default)] +pub struct Json(pub T); + +#[async_trait] +impl FromRequest for Json +where + T: DeserializeOwned, + B: http_body::Body + Send, + B::Data: Send, + B::Error: Into, +{ + type Rejection = JsonRejection; + + async fn from_request(req: &mut RequestParts) -> Result { + use bytes::Buf; + + if has_content_type(req, "application/json")? { + let body = take_body(req)?; + + let buf = hyper::body::aggregate(body) + .await + .map_err(InvalidJsonBody::from_err)?; + + let value = serde_json::from_reader(buf.reader()).map_err(InvalidJsonBody::from_err)?; + + Ok(Json(value)) + } else { + Err(MissingJsonContentType.into()) + } + } +} + +impl Deref for Json { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/src/extract/mod.rs b/src/extract/mod.rs index 0f286f2c..f612c622 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -244,31 +244,41 @@ //! //! [`body::Body`]: crate::body::Body -use crate::{response::IntoResponse, util::ByteStr}; +use crate::response::IntoResponse; use async_trait::async_trait; -use bytes::{Buf, Bytes}; -use futures_util::stream::Stream; -use http::{header, Extensions, HeaderMap, Method, Request, Response, Uri, Version}; +use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version}; use rejection::*; -use serde::de::DeserializeOwned; -use std::{ - collections::HashMap, - convert::Infallible, - ops::Deref, - pin::Pin, - str::FromStr, - task::{Context, Poll}, -}; +use std::convert::Infallible; pub mod connect_info; pub mod extractor_middleware; pub mod rejection; -#[doc(inline)] -pub use self::extractor_middleware::extractor_middleware; +mod content_length_limit; +mod extension; +mod form; +mod json; +mod query; +mod raw_query; +mod request_parts; +mod tuple; +mod url_params; +mod url_params_map; #[doc(inline)] -pub use self::connect_info::ConnectInfo; +pub use self::{ + connect_info::ConnectInfo, + content_length_limit::ContentLengthLimit, + extension::Extension, + extractor_middleware::extractor_middleware, + form::Form, + json::Json, + query::Query, + raw_query::RawQuery, + request_parts::{Body, BodyStream}, + url_params::UrlParams, + url_params_map::UrlParamsMap, +}; #[cfg(feature = "multipart")] #[cfg_attr(docsrs, doc(cfg(feature = "multipart")))] @@ -279,6 +289,15 @@ pub mod multipart; #[doc(inline)] pub use self::multipart::Multipart; +#[cfg(feature = "headers")] +#[cfg_attr(docsrs, doc(cfg(feature = "headers")))] +mod typed_header; + +#[cfg(feature = "headers")] +#[cfg_attr(docsrs, doc(cfg(feature = "headers")))] +#[doc(inline)] +pub use self::typed_header::TypedHeader; + /// Types that can be created from requests. /// /// See the [module docs](crate::extract) for more details. @@ -480,46 +499,6 @@ impl RequestParts { } } -#[async_trait] -impl FromRequest for () -where - B: Send, -{ - type Rejection = Infallible; - - async fn from_request(_: &mut RequestParts) -> Result<(), Self::Rejection> { - Ok(()) - } -} - -macro_rules! impl_from_request { - () => { - }; - - ( $head:ident, $($tail:ident),* $(,)? ) => { - #[async_trait] - #[allow(non_snake_case)] - impl FromRequest for ($head, $($tail,)*) - where - $head: FromRequest + Send, - $( $tail: FromRequest + Send, )* - B: Send, - { - type Rejection = Response; - - async fn from_request(req: &mut RequestParts) -> Result { - let $head = $head::from_request(req).await.map_err(IntoResponse::into_response)?; - $( let $tail = $tail::from_request(req).await.map_err(IntoResponse::into_response)?; )* - Ok(($head, $($tail,)*)) - } - } - - impl_from_request!($($tail,)*); - }; -} - -impl_from_request!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); - #[async_trait] impl FromRequest for Option where @@ -546,220 +525,6 @@ where } } -/// Extractor that deserializes query strings into some type. -/// -/// `T` is expected to implement [`serde::Deserialize`]. -/// -/// # Example -/// -/// ```rust,no_run -/// use axum::prelude::*; -/// use serde::Deserialize; -/// -/// #[derive(Deserialize)] -/// struct Pagination { -/// page: usize, -/// per_page: usize, -/// } -/// -/// // This will parse query strings like `?page=2&per_page=30` into `Pagination` -/// // structs. -/// async fn list_things(pagination: extract::Query) { -/// let pagination: Pagination = pagination.0; -/// -/// // ... -/// } -/// -/// let app = route("/list_things", get(list_things)); -/// # async { -/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -/// # }; -/// ``` -/// -/// If the query string cannot be parsed it will reject the request with a `400 -/// Bad Request` response. -#[derive(Debug, Clone, Copy, Default)] -pub struct Query(pub T); - -#[async_trait] -impl FromRequest for Query -where - T: DeserializeOwned, - B: Send, -{ - type Rejection = QueryRejection; - - async fn from_request(req: &mut RequestParts) -> Result { - let query = req - .uri() - .ok_or(UriAlreadyExtracted)? - .query() - .ok_or(QueryStringMissing)?; - let value = serde_urlencoded::from_str(query) - .map_err(FailedToDeserializeQueryString::new::)?; - Ok(Query(value)) - } -} - -impl Deref for Query { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -/// Extractor that deserializes `application/x-www-form-urlencoded` requests -/// into some type. -/// -/// `T` is expected to implement [`serde::Deserialize`]. -/// -/// # Example -/// -/// ```rust,no_run -/// use axum::prelude::*; -/// use serde::Deserialize; -/// -/// #[derive(Deserialize)] -/// struct SignUp { -/// username: String, -/// password: String, -/// } -/// -/// async fn accept_form(form: extract::Form) { -/// let sign_up: SignUp = form.0; -/// -/// // ... -/// } -/// -/// let app = route("/sign_up", post(accept_form)); -/// # async { -/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -/// # }; -/// ``` -/// -/// Note that `Content-Type: multipart/form-data` requests are not supported. -#[derive(Debug, Clone, Copy, Default)] -pub struct Form(pub T); - -#[async_trait] -impl FromRequest for Form -where - T: DeserializeOwned, - B: http_body::Body + Send, - B::Data: Send, - B::Error: Into, -{ - type Rejection = FormRejection; - - #[allow(warnings)] - async fn from_request(req: &mut RequestParts) -> Result { - if !has_content_type(&req, "application/x-www-form-urlencoded")? { - Err(InvalidFormContentType)?; - } - - if req.method().ok_or(MethodAlreadyExtracted)? == Method::GET { - let query = req - .uri() - .ok_or(UriAlreadyExtracted)? - .query() - .ok_or(QueryStringMissing)?; - let value = serde_urlencoded::from_str(query) - .map_err(FailedToDeserializeQueryString::new::)?; - Ok(Form(value)) - } else { - let body = take_body(req)?; - let chunks = hyper::body::aggregate(body) - .await - .map_err(FailedToBufferBody::from_err)?; - let value = serde_urlencoded::from_reader(chunks.reader()) - .map_err(FailedToDeserializeQueryString::new::)?; - - Ok(Form(value)) - } - } -} - -impl Deref for Form { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -/// Extractor that deserializes request bodies into some type. -/// -/// `T` is expected to implement [`serde::Deserialize`]. -/// -/// # Example -/// -/// ```rust,no_run -/// use axum::prelude::*; -/// use serde::Deserialize; -/// -/// #[derive(Deserialize)] -/// struct CreateUser { -/// email: String, -/// password: String, -/// } -/// -/// async fn create_user(payload: extract::Json) { -/// let payload: CreateUser = payload.0; -/// -/// // ... -/// } -/// -/// let app = route("/users", post(create_user)); -/// # async { -/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -/// # }; -/// ``` -/// -/// If the query string cannot be parsed it will reject the request with a `400 -/// Bad Request` response. -/// -/// The request is required to have a `Content-Type: application/json` header. -#[derive(Debug, Clone, Copy, Default)] -pub struct Json(pub T); - -#[async_trait] -impl FromRequest for Json -where - T: DeserializeOwned, - B: http_body::Body + Send, - B::Data: Send, - B::Error: Into, -{ - type Rejection = JsonRejection; - - async fn from_request(req: &mut RequestParts) -> Result { - use bytes::Buf; - - if has_content_type(req, "application/json")? { - let body = take_body(req)?; - - let buf = hyper::body::aggregate(body) - .await - .map_err(InvalidJsonBody::from_err)?; - - let value = serde_json::from_reader(buf.reader()).map_err(InvalidJsonBody::from_err)?; - - Ok(Json(value)) - } else { - Err(MissingJsonContentType.into()) - } - } -} - -impl Deref for Json { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - fn has_content_type( req: &RequestParts, expected_content_type: &str, @@ -783,599 +548,6 @@ fn has_content_type( Ok(content_type.starts_with(expected_content_type)) } -/// Extractor that gets a value from request extensions. -/// -/// This is commonly used to share state across handlers. -/// -/// # Example -/// -/// ```rust,no_run -/// use axum::{AddExtensionLayer, prelude::*}; -/// use std::sync::Arc; -/// -/// // Some shared state used throughout our application -/// struct State { -/// // ... -/// } -/// -/// async fn handler(state: extract::Extension>) { -/// // ... -/// } -/// -/// let state = Arc::new(State { /* ... */ }); -/// -/// let app = route("/", get(handler)) -/// // Add middleware that inserts the state into all incoming request's -/// // extensions. -/// .layer(AddExtensionLayer::new(state)); -/// # async { -/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -/// # }; -/// ``` -/// -/// If the extension is missing it will reject the request with a `500 Interal -/// Server Error` response. -#[derive(Debug, Clone, Copy)] -pub struct Extension(pub T); - -#[async_trait] -impl FromRequest for Extension -where - T: Clone + Send + Sync + 'static, - B: Send, -{ - type Rejection = ExtensionRejection; - - async fn from_request(req: &mut RequestParts) -> Result { - let value = req - .extensions() - .ok_or(ExtensionsAlreadyExtracted)? - .get::() - .ok_or_else(|| { - MissingExtension::from_err(format!( - "Extension of type `{}` was not found. Perhaps you forgot to add it?", - std::any::type_name::() - )) - }) - .map(|x| x.clone())?; - - Ok(Extension(value)) - } -} - -impl Deref for Extension { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -#[async_trait] -impl FromRequest for Bytes -where - B: http_body::Body + Send, - B::Data: Send, - B::Error: Into, -{ - type Rejection = BytesRejection; - - async fn from_request(req: &mut RequestParts) -> Result { - let body = take_body(req)?; - - let bytes = hyper::body::to_bytes(body) - .await - .map_err(FailedToBufferBody::from_err)?; - - Ok(bytes) - } -} - -#[async_trait] -impl FromRequest for String -where - B: http_body::Body + Send, - B::Data: Send, - B::Error: Into, -{ - type Rejection = StringRejection; - - async fn from_request(req: &mut RequestParts) -> Result { - let body = take_body(req)?; - - let bytes = hyper::body::to_bytes(body) - .await - .map_err(FailedToBufferBody::from_err)? - .to_vec(); - - let string = String::from_utf8(bytes).map_err(InvalidUtf8::from_err)?; - - Ok(string) - } -} - -/// Extractor that extracts the request body as a [`Stream`]. -/// -/// # Example -/// -/// ```rust,no_run -/// use axum::prelude::*; -/// use futures::StreamExt; -/// -/// async fn handler(mut stream: extract::BodyStream) { -/// while let Some(chunk) = stream.next().await { -/// // ... -/// } -/// } -/// -/// let app = route("/users", get(handler)); -/// # async { -/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -/// # }; -/// ``` -/// -/// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html -#[derive(Debug)] -pub struct BodyStream(B); - -impl Stream for BodyStream -where - B: http_body::Body + Unpin, -{ - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_data(cx) - } -} - -#[async_trait] -impl FromRequest for BodyStream -where - B: http_body::Body + Unpin + Send, -{ - type Rejection = BodyAlreadyExtracted; - - async fn from_request(req: &mut RequestParts) -> Result { - let body = take_body(req)?; - let stream = BodyStream(body); - Ok(stream) - } -} - -/// Extractor that extracts the request body. -/// -/// # Example -/// -/// ```rust,no_run -/// use axum::prelude::*; -/// use futures::StreamExt; -/// -/// async fn handler(extract::Body(body): extract::Body) { -/// // ... -/// } -/// -/// let app = route("/users", get(handler)); -/// # async { -/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -/// # }; -/// ``` -#[derive(Debug, Default, Clone)] -pub struct Body(pub B); - -#[async_trait] -impl FromRequest for Body -where - B: Send, -{ - type Rejection = BodyAlreadyExtracted; - - async fn from_request(req: &mut RequestParts) -> Result { - let body = take_body(req)?; - Ok(Self(body)) - } -} - -#[async_trait] -impl FromRequest for Request -where - B: Send, -{ - type Rejection = RequestAlreadyExtracted; - - async fn from_request(req: &mut RequestParts) -> Result { - let RequestParts { - method, - uri, - version, - headers, - extensions, - body, - } = req; - - let all_parts = method - .as_ref() - .zip(version.as_ref()) - .zip(uri.as_ref()) - .zip(extensions.as_ref()) - .zip(body.as_ref()) - .zip(headers.as_ref()); - - if all_parts.is_some() { - Ok(req.into_request()) - } else { - Err(RequestAlreadyExtracted) - } - } -} - -#[async_trait] -impl FromRequest for Method -where - B: Send, -{ - type Rejection = MethodAlreadyExtracted; - - async fn from_request(req: &mut RequestParts) -> Result { - req.take_method().ok_or(MethodAlreadyExtracted) - } -} - -#[async_trait] -impl FromRequest for Uri -where - B: Send, -{ - type Rejection = UriAlreadyExtracted; - - async fn from_request(req: &mut RequestParts) -> Result { - req.take_uri().ok_or(UriAlreadyExtracted) - } -} - -#[async_trait] -impl FromRequest for Version -where - B: Send, -{ - type Rejection = VersionAlreadyExtracted; - - async fn from_request(req: &mut RequestParts) -> Result { - req.take_version().ok_or(VersionAlreadyExtracted) - } -} - -#[async_trait] -impl FromRequest for HeaderMap -where - B: Send, -{ - type Rejection = HeadersAlreadyExtracted; - - async fn from_request(req: &mut RequestParts) -> Result { - req.take_headers().ok_or(HeadersAlreadyExtracted) - } -} - -/// Extractor that will reject requests with a body larger than some size. -/// -/// # Example -/// -/// ```rust,no_run -/// use axum::prelude::*; -/// -/// async fn handler(body: extract::ContentLengthLimit) { -/// // ... -/// } -/// -/// let app = route("/", post(handler)); -/// # async { -/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -/// # }; -/// ``` -/// -/// This requires the request to have a `Content-Length` header. -#[derive(Debug, Clone)] -pub struct ContentLengthLimit(pub T); - -#[async_trait] -impl FromRequest for ContentLengthLimit -where - T: FromRequest, - B: Send, -{ - type Rejection = ContentLengthLimitRejection; - - async fn from_request(req: &mut RequestParts) -> Result { - let content_length = req - .headers() - .ok_or(ContentLengthLimitRejection::HeadersAlreadyExtracted( - HeadersAlreadyExtracted, - ))? - .get(http::header::CONTENT_LENGTH); - - let content_length = - content_length.and_then(|value| value.to_str().ok()?.parse::().ok()); - - if let Some(length) = content_length { - if length > N { - return Err(ContentLengthLimitRejection::PayloadTooLarge( - PayloadTooLarge, - )); - } - } else { - return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired)); - }; - - let value = T::from_request(req) - .await - .map_err(ContentLengthLimitRejection::Inner)?; - - Ok(Self(value)) - } -} - -impl Deref for ContentLengthLimit { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -/// Extractor that will get captures from the URL. -/// -/// # Example -/// -/// ```rust,no_run -/// use axum::prelude::*; -/// -/// async fn users_show(params: extract::UrlParamsMap) { -/// let id: Option<&str> = params.get("id"); -/// -/// // ... -/// } -/// -/// let app = route("/users/:id", get(users_show)); -/// # async { -/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -/// # }; -/// ``` -/// -/// Note that you can only have one URL params extractor per handler. If you -/// have multiple it'll response with `500 Internal Server Error`. -#[derive(Debug)] -pub struct UrlParamsMap(HashMap); - -impl UrlParamsMap { - /// Look up the value for a key. - pub fn get(&self, key: &str) -> Option<&str> { - self.0.get(&ByteStr::new(key)).map(|s| s.as_str()) - } - - /// Look up the value for a key and parse it into a value of type `T`. - pub fn get_typed(&self, key: &str) -> Option> - where - T: FromStr, - { - self.get(key).map(str::parse) - } -} - -#[async_trait] -impl FromRequest for UrlParamsMap -where - B: Send, -{ - type Rejection = MissingRouteParams; - - async fn from_request(req: &mut RequestParts) -> Result { - if let Some(params) = req - .extensions_mut() - .and_then(|ext| ext.get_mut::>()) - { - if let Some(params) = params { - Ok(Self(params.0.iter().cloned().collect())) - } else { - Ok(Self(Default::default())) - } - } else { - Err(MissingRouteParams) - } - } -} - -/// Extractor that will get captures from the URL and parse them. -/// -/// # Example -/// -/// ```rust,no_run -/// use axum::{extract::UrlParams, prelude::*}; -/// use uuid::Uuid; -/// -/// async fn users_teams_show( -/// UrlParams(params): UrlParams<(Uuid, Uuid)>, -/// ) { -/// let user_id: Uuid = params.0; -/// let team_id: Uuid = params.1; -/// -/// // ... -/// } -/// -/// let app = route("/users/:user_id/team/:team_id", get(users_teams_show)); -/// # async { -/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -/// # }; -/// ``` -/// -/// Note that you can only have one URL params extractor per handler. If you -/// have multiple it'll response with `500 Internal Server Error`. -#[derive(Debug)] -pub struct UrlParams(pub T); - -macro_rules! impl_parse_url { - () => {}; - - ( $head:ident, $($tail:ident),* $(,)? ) => { - #[async_trait] - impl FromRequest for UrlParams<($head, $($tail,)*)> - where - $head: FromStr + Send, - $( $tail: FromStr + Send, )* - B: Send, - { - type Rejection = UrlParamsRejection; - - #[allow(non_snake_case)] - async fn from_request(req: &mut RequestParts) -> Result { - let params = if let Some(params) = req - .extensions_mut() - .and_then(|ext| { - ext.get_mut::>() - }) - { - if let Some(params) = params { - params.0.clone() - } else { - Default::default() - } - } else { - return Err(MissingRouteParams.into()) - }; - - if let [(_, $head), $((_, $tail),)*] = &*params { - let $head = if let Ok(x) = $head.as_str().parse::<$head>() { - x - } else { - return Err(InvalidUrlParam::new::<$head>().into()); - }; - - $( - let $tail = if let Ok(x) = $tail.as_str().parse::<$tail>() { - x - } else { - return Err(InvalidUrlParam::new::<$tail>().into()); - }; - )* - - Ok(UrlParams(($head, $($tail,)*))) - } else { - Err(MissingRouteParams.into()) - } - } - } - - impl_parse_url!($($tail,)*); - }; -} - -impl_parse_url!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); - -impl Deref for UrlParams { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - fn take_body(req: &mut RequestParts) -> Result { req.take_body().ok_or(BodyAlreadyExtracted) } - -/// Extractor that extracts a typed header value from [`headers`]. -/// -/// # Example -/// -/// ```rust,no_run -/// use axum::{extract::TypedHeader, prelude::*}; -/// use headers::UserAgent; -/// -/// async fn users_teams_show( -/// TypedHeader(user_agent): TypedHeader, -/// ) { -/// // ... -/// } -/// -/// let app = route("/users/:user_id/team/:team_id", get(users_teams_show)); -/// # async { -/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -/// # }; -/// ``` -#[cfg(feature = "headers")] -#[cfg_attr(docsrs, doc(cfg(feature = "headers")))] -#[derive(Debug, Clone, Copy)] -pub struct TypedHeader(pub T); - -#[cfg(feature = "headers")] -#[cfg_attr(docsrs, doc(cfg(feature = "headers")))] -#[async_trait] -impl FromRequest for TypedHeader -where - T: headers::Header, - B: Send, -{ - type Rejection = TypedHeaderRejection; - - async fn from_request(req: &mut RequestParts) -> Result { - let empty_headers = HeaderMap::new(); - let header_values = if let Some(headers) = req.headers() { - headers.get_all(T::name()) - } else { - empty_headers.get_all(T::name()) - }; - - T::decode(&mut header_values.iter()) - .map(Self) - .map_err(|err| rejection::TypedHeaderRejection { - err, - name: T::name(), - }) - } -} - -#[cfg(feature = "headers")] -#[cfg_attr(docsrs, doc(cfg(feature = "headers")))] -impl Deref for TypedHeader { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -/// Extractor that extracts the raw query string, without parsing it. -/// -/// # Example -/// -/// ```rust,no_run -/// use axum::prelude::*; -/// use futures::StreamExt; -/// -/// async fn handler(extract::RawQuery(query): extract::RawQuery) { -/// // ... -/// } -/// -/// let app = route("/users", get(handler)); -/// # async { -/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -/// # }; -/// ``` -#[derive(Debug)] -pub struct RawQuery(pub Option); - -#[async_trait] -impl FromRequest for RawQuery -where - B: Send, -{ - type Rejection = Infallible; - - async fn from_request(req: &mut RequestParts) -> Result { - let query = req - .uri() - .and_then(|uri| uri.query()) - .map(|query| query.to_string()); - Ok(Self(query)) - } -} diff --git a/src/extract/query.rs b/src/extract/query.rs new file mode 100644 index 00000000..d10a0801 --- /dev/null +++ b/src/extract/query.rs @@ -0,0 +1,67 @@ +use super::{rejection::*, FromRequest, RequestParts}; +use async_trait::async_trait; +use serde::de::DeserializeOwned; +use std::ops::Deref; + +/// Extractor that deserializes query strings into some type. +/// +/// `T` is expected to implement [`serde::Deserialize`]. +/// +/// # Example +/// +/// ```rust,no_run +/// use axum::prelude::*; +/// use serde::Deserialize; +/// +/// #[derive(Deserialize)] +/// struct Pagination { +/// page: usize, +/// per_page: usize, +/// } +/// +/// // This will parse query strings like `?page=2&per_page=30` into `Pagination` +/// // structs. +/// async fn list_things(pagination: extract::Query) { +/// let pagination: Pagination = pagination.0; +/// +/// // ... +/// } +/// +/// let app = route("/list_things", get(list_things)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// If the query string cannot be parsed it will reject the request with a `400 +/// Bad Request` response. +#[derive(Debug, Clone, Copy, Default)] +pub struct Query(pub T); + +#[async_trait] +impl FromRequest for Query +where + T: DeserializeOwned, + B: Send, +{ + type Rejection = QueryRejection; + + async fn from_request(req: &mut RequestParts) -> Result { + let query = req + .uri() + .ok_or(UriAlreadyExtracted)? + .query() + .ok_or(QueryStringMissing)?; + let value = serde_urlencoded::from_str(query) + .map_err(FailedToDeserializeQueryString::new::)?; + Ok(Query(value)) + } +} + +impl Deref for Query { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/src/extract/raw_query.rs b/src/extract/raw_query.rs new file mode 100644 index 00000000..402272d3 --- /dev/null +++ b/src/extract/raw_query.rs @@ -0,0 +1,39 @@ +use super::{FromRequest, RequestParts}; +use async_trait::async_trait; +use std::convert::Infallible; + +/// Extractor that extracts the raw query string, without parsing it. +/// +/// # Example +/// +/// ```rust,no_run +/// use axum::prelude::*; +/// use futures::StreamExt; +/// +/// async fn handler(extract::RawQuery(query): extract::RawQuery) { +/// // ... +/// } +/// +/// let app = route("/users", get(handler)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +#[derive(Debug)] +pub struct RawQuery(pub Option); + +#[async_trait] +impl FromRequest for RawQuery +where + B: Send, +{ + type Rejection = Infallible; + + async fn from_request(req: &mut RequestParts) -> Result { + let query = req + .uri() + .and_then(|uri| uri.query()) + .map(|query| query.to_string()); + Ok(Self(query)) + } +} diff --git a/src/extract/request_parts.rs b/src/extract/request_parts.rs new file mode 100644 index 00000000..168de267 --- /dev/null +++ b/src/extract/request_parts.rs @@ -0,0 +1,215 @@ +use super::{rejection::*, take_body, FromRequest, RequestParts}; +use async_trait::async_trait; +use bytes::Bytes; +use futures_util::stream::Stream; +use http::{HeaderMap, Method, Request, Uri, Version}; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +#[async_trait] +impl FromRequest for Request +where + B: Send, +{ + type Rejection = RequestAlreadyExtracted; + + async fn from_request(req: &mut RequestParts) -> Result { + let RequestParts { + method, + uri, + version, + headers, + extensions, + body, + } = req; + + let all_parts = method + .as_ref() + .zip(version.as_ref()) + .zip(uri.as_ref()) + .zip(extensions.as_ref()) + .zip(body.as_ref()) + .zip(headers.as_ref()); + + if all_parts.is_some() { + Ok(req.into_request()) + } else { + Err(RequestAlreadyExtracted) + } + } +} + +#[async_trait] +impl FromRequest for Body +where + B: Send, +{ + type Rejection = BodyAlreadyExtracted; + + async fn from_request(req: &mut RequestParts) -> Result { + let body = take_body(req)?; + Ok(Self(body)) + } +} + +#[async_trait] +impl FromRequest for Method +where + B: Send, +{ + type Rejection = MethodAlreadyExtracted; + + async fn from_request(req: &mut RequestParts) -> Result { + req.take_method().ok_or(MethodAlreadyExtracted) + } +} + +#[async_trait] +impl FromRequest for Uri +where + B: Send, +{ + type Rejection = UriAlreadyExtracted; + + async fn from_request(req: &mut RequestParts) -> Result { + req.take_uri().ok_or(UriAlreadyExtracted) + } +} + +#[async_trait] +impl FromRequest for Version +where + B: Send, +{ + type Rejection = VersionAlreadyExtracted; + + async fn from_request(req: &mut RequestParts) -> Result { + req.take_version().ok_or(VersionAlreadyExtracted) + } +} + +#[async_trait] +impl FromRequest for HeaderMap +where + B: Send, +{ + type Rejection = HeadersAlreadyExtracted; + + async fn from_request(req: &mut RequestParts) -> Result { + req.take_headers().ok_or(HeadersAlreadyExtracted) + } +} + +/// Extractor that extracts the request body as a [`Stream`]. +/// +/// # Example +/// +/// ```rust,no_run +/// use axum::prelude::*; +/// use futures::StreamExt; +/// +/// async fn handler(mut stream: extract::BodyStream) { +/// while let Some(chunk) = stream.next().await { +/// // ... +/// } +/// } +/// +/// let app = route("/users", get(handler)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html +#[derive(Debug)] +pub struct BodyStream(B); + +impl Stream for BodyStream +where + B: http_body::Body + Unpin, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_data(cx) + } +} + +#[async_trait] +impl FromRequest for BodyStream +where + B: http_body::Body + Unpin + Send, +{ + type Rejection = BodyAlreadyExtracted; + + async fn from_request(req: &mut RequestParts) -> Result { + let body = take_body(req)?; + let stream = BodyStream(body); + Ok(stream) + } +} + +/// Extractor that extracts the request body. +/// +/// # Example +/// +/// ```rust,no_run +/// use axum::prelude::*; +/// use futures::StreamExt; +/// +/// async fn handler(extract::Body(body): extract::Body) { +/// // ... +/// } +/// +/// let app = route("/users", get(handler)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +#[derive(Debug, Default, Clone)] +pub struct Body(pub B); + +#[async_trait] +impl FromRequest for Bytes +where + B: http_body::Body + Send, + B::Data: Send, + B::Error: Into, +{ + type Rejection = BytesRejection; + + async fn from_request(req: &mut RequestParts) -> Result { + let body = take_body(req)?; + + let bytes = hyper::body::to_bytes(body) + .await + .map_err(FailedToBufferBody::from_err)?; + + Ok(bytes) + } +} + +#[async_trait] +impl FromRequest for String +where + B: http_body::Body + Send, + B::Data: Send, + B::Error: Into, +{ + type Rejection = StringRejection; + + async fn from_request(req: &mut RequestParts) -> Result { + let body = take_body(req)?; + + let bytes = hyper::body::to_bytes(body) + .await + .map_err(FailedToBufferBody::from_err)? + .to_vec(); + + let string = String::from_utf8(bytes).map_err(InvalidUtf8::from_err)?; + + Ok(string) + } +} diff --git a/src/extract/tuple.rs b/src/extract/tuple.rs new file mode 100644 index 00000000..5d3b2bff --- /dev/null +++ b/src/extract/tuple.rs @@ -0,0 +1,45 @@ +use super::{FromRequest, RequestParts}; +use crate::response::IntoResponse; +use async_trait::async_trait; +use http::Response; +use std::convert::Infallible; + +#[async_trait] +impl FromRequest for () +where + B: Send, +{ + type Rejection = Infallible; + + async fn from_request(_: &mut RequestParts) -> Result<(), Self::Rejection> { + Ok(()) + } +} + +macro_rules! impl_from_request { + () => { + }; + + ( $head:ident, $($tail:ident),* $(,)? ) => { + #[async_trait] + #[allow(non_snake_case)] + impl FromRequest for ($head, $($tail,)*) + where + $head: FromRequest + Send, + $( $tail: FromRequest + Send, )* + B: Send, + { + type Rejection = Response; + + async fn from_request(req: &mut RequestParts) -> Result { + let $head = $head::from_request(req).await.map_err(IntoResponse::into_response)?; + $( let $tail = $tail::from_request(req).await.map_err(IntoResponse::into_response)?; )* + Ok(($head, $($tail,)*)) + } + } + + impl_from_request!($($tail,)*); + }; +} + +impl_from_request!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); diff --git a/src/extract/typed_header.rs b/src/extract/typed_header.rs new file mode 100644 index 00000000..3307b339 --- /dev/null +++ b/src/extract/typed_header.rs @@ -0,0 +1,61 @@ +use super::{rejection::TypedHeaderRejection, take_body, FromRequest, RequestParts}; +use async_trait::async_trait; +use headers::HeaderMap; +use std::ops::Deref; + +/// Extractor that extracts a typed header value from [`headers`]. +/// +/// # Example +/// +/// ```rust,no_run +/// use axum::{extract::TypedHeader, prelude::*}; +/// use headers::UserAgent; +/// +/// async fn users_teams_show( +/// TypedHeader(user_agent): TypedHeader, +/// ) { +/// // ... +/// } +/// +/// let app = route("/users/:user_id/team/:team_id", get(users_teams_show)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +#[cfg(feature = "headers")] +#[cfg_attr(docsrs, doc(cfg(feature = "headers")))] +#[derive(Debug, Clone, Copy)] +pub struct TypedHeader(pub T); + +#[async_trait] +impl FromRequest for TypedHeader +where + T: headers::Header, + B: Send, +{ + type Rejection = TypedHeaderRejection; + + async fn from_request(req: &mut RequestParts) -> Result { + let empty_headers = HeaderMap::new(); + let header_values = if let Some(headers) = req.headers() { + headers.get_all(T::name()) + } else { + empty_headers.get_all(T::name()) + }; + + T::decode(&mut header_values.iter()) + .map(Self) + .map_err(|err| TypedHeaderRejection { + err, + name: T::name(), + }) + } +} + +impl Deref for TypedHeader { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/src/extract/url_params.rs b/src/extract/url_params.rs new file mode 100644 index 00000000..e06deb6f --- /dev/null +++ b/src/extract/url_params.rs @@ -0,0 +1,97 @@ +use super::{rejection::*, FromRequest, RequestParts}; +use async_trait::async_trait; +use std::{ops::Deref, str::FromStr}; + +/// Extractor that will get captures from the URL and parse them. +/// +/// # Example +/// +/// ```rust,no_run +/// use axum::{extract::UrlParams, prelude::*}; +/// use uuid::Uuid; +/// +/// async fn users_teams_show( +/// UrlParams(params): UrlParams<(Uuid, Uuid)>, +/// ) { +/// let user_id: Uuid = params.0; +/// let team_id: Uuid = params.1; +/// +/// // ... +/// } +/// +/// let app = route("/users/:user_id/team/:team_id", get(users_teams_show)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// Note that you can only have one URL params extractor per handler. If you +/// have multiple it'll response with `500 Internal Server Error`. +#[derive(Debug)] +pub struct UrlParams(pub T); + +macro_rules! impl_parse_url { + () => {}; + + ( $head:ident, $($tail:ident),* $(,)? ) => { + #[async_trait] + impl FromRequest for UrlParams<($head, $($tail,)*)> + where + $head: FromStr + Send, + $( $tail: FromStr + Send, )* + B: Send, + { + type Rejection = UrlParamsRejection; + + #[allow(non_snake_case)] + async fn from_request(req: &mut RequestParts) -> Result { + let params = if let Some(params) = req + .extensions_mut() + .and_then(|ext| { + ext.get_mut::>() + }) + { + if let Some(params) = params { + params.0.clone() + } else { + Default::default() + } + } else { + return Err(MissingRouteParams.into()) + }; + + if let [(_, $head), $((_, $tail),)*] = &*params { + let $head = if let Ok(x) = $head.as_str().parse::<$head>() { + x + } else { + return Err(InvalidUrlParam::new::<$head>().into()); + }; + + $( + let $tail = if let Ok(x) = $tail.as_str().parse::<$tail>() { + x + } else { + return Err(InvalidUrlParam::new::<$tail>().into()); + }; + )* + + Ok(UrlParams(($head, $($tail,)*))) + } else { + Err(MissingRouteParams.into()) + } + } + } + + impl_parse_url!($($tail,)*); + }; +} + +impl_parse_url!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); + +impl Deref for UrlParams { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/src/extract/url_params_map.rs b/src/extract/url_params_map.rs new file mode 100644 index 00000000..0673b5c9 --- /dev/null +++ b/src/extract/url_params_map.rs @@ -0,0 +1,66 @@ +use super::{rejection::*, FromRequest, RequestParts}; +use crate::util::ByteStr; +use async_trait::async_trait; +use std::{collections::HashMap, str::FromStr}; + +/// Extractor that will get captures from the URL. +/// +/// # Example +/// +/// ```rust,no_run +/// use axum::prelude::*; +/// +/// async fn users_show(params: extract::UrlParamsMap) { +/// let id: Option<&str> = params.get("id"); +/// +/// // ... +/// } +/// +/// let app = route("/users/:id", get(users_show)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// Note that you can only have one URL params extractor per handler. If you +/// have multiple it'll response with `500 Internal Server Error`. +#[derive(Debug)] +pub struct UrlParamsMap(HashMap); + +impl UrlParamsMap { + /// Look up the value for a key. + pub fn get(&self, key: &str) -> Option<&str> { + self.0.get(&ByteStr::new(key)).map(|s| s.as_str()) + } + + /// Look up the value for a key and parse it into a value of type `T`. + pub fn get_typed(&self, key: &str) -> Option> + where + T: FromStr, + { + self.get(key).map(str::parse) + } +} + +#[async_trait] +impl FromRequest for UrlParamsMap +where + B: Send, +{ + type Rejection = MissingRouteParams; + + async fn from_request(req: &mut RequestParts) -> Result { + if let Some(params) = req + .extensions_mut() + .and_then(|ext| ext.get_mut::>()) + { + if let Some(params) = params { + Ok(Self(params.0.iter().cloned().collect())) + } else { + Ok(Self(Default::default())) + } + } else { + Err(MissingRouteParams) + } + } +}