diff --git a/Cargo.toml b/Cargo.toml index c9d18a92..6b43ec9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "axum", + "axum-core", "axum-debug", "axum-extra", "examples/*", diff --git a/axum-core/CHANGELOG.md b/axum-core/CHANGELOG.md new file mode 100644 index 00000000..16a28cbb --- /dev/null +++ b/axum-core/CHANGELOG.md @@ -0,0 +1,10 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +# Unreleased + +- None. diff --git a/axum-core/Cargo.toml b/axum-core/Cargo.toml new file mode 100644 index 00000000..7681c68f --- /dev/null +++ b/axum-core/Cargo.toml @@ -0,0 +1,24 @@ +[package] +categories = ["asynchronous", "network-programming", "web-programming"] +description = "Core types and traits for axum" +edition = "2018" +homepage = "https://github.com/tokio-rs/axum" +keywords = ["http", "web", "framework"] +license = "MIT" +name = "axum-core" +readme = "README.md" +repository = "https://github.com/tokio-rs/axum" +version = "0.1.0" + +[dependencies] +async-trait = "0.1" +bytes = "1.0" +futures-util = { version = "0.3", default-features = false, features = ["alloc"] } +http = "0.2" +http-body = "0.4" +mime = "0.3.16" + +[dev-dependencies] +futures-util = "0.3" +axum = { path = "../axum", version = "0.3" } +hyper = "0.14" diff --git a/axum-core/LICENSE b/axum-core/LICENSE new file mode 100644 index 00000000..538d04ab --- /dev/null +++ b/axum-core/LICENSE @@ -0,0 +1,7 @@ +Copyright 2021 Axum Contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/axum-core/README.md b/axum-core/README.md new file mode 100644 index 00000000..2506b283 --- /dev/null +++ b/axum-core/README.md @@ -0,0 +1,45 @@ +# axum-core + +[![Build status](https://github.com/tokio-rs/axum/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/tokio-rs/axum-core/actions/workflows/CI.yml) +[![Crates.io](https://img.shields.io/crates/v/axum-core)](https://crates.io/crates/axum-core) +[![Documentation](https://docs.rs/axum-core/badge.svg)](https://docs.rs/axum-core) + +Core types and traits for axum. + +More information about this crate can be found in the [crate documentation][docs]. + +## Safety + +This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in 100% safe Rust. + +## Minimum supported Rust version + +axum-core's MSRV is 1.54. + +## Getting Help + +You're also welcome to ask in the [Discord channel][chat] or open an [issue] +with your question. + +## Contributing + +:balloon: Thanks for your help improving the project! We are so happy to have +you! We have a [contributing guide][contributing] to help you get involved in the +`axum` project. + +## License + +This project is licensed under the [MIT license][license]. + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally submitted +for inclusion in `axum` by you, shall be licensed as MIT, without any +additional terms or conditions. + +[`axum`]: https://crates.io/crates/axum +[chat]: https://discord.gg/tokio +[contributing]: /CONTRIBUTING.md +[docs]: https://docs.rs/axum-core +[license]: /axum-core/LICENSE +[issue]: https://github.com/tokio-rs/axum/issues/new diff --git a/axum-core/src/body.rs b/axum-core/src/body.rs new file mode 100644 index 00000000..9f254089 --- /dev/null +++ b/axum-core/src/body.rs @@ -0,0 +1,92 @@ +//! HTTP body utilities. + +use crate::{BoxError, Error}; +use bytes::Bytes; +use bytes::{Buf, BufMut}; +use http_body::Body; + +/// A boxed [`Body`] trait object. +/// +/// This is used in axum as the response body type for applications. It's +/// necessary to unify multiple response bodies types into one. +pub type BoxBody = http_body::combinators::UnsyncBoxBody; + +/// Convert a [`http_body::Body`] into a [`BoxBody`]. +pub fn boxed(body: B) -> BoxBody +where + B: http_body::Body + Send + 'static, + B::Error: Into, +{ + try_downcast(body).unwrap_or_else(|body| body.map_err(Error::new).boxed_unsync()) +} + +pub(crate) fn try_downcast(k: K) -> Result +where + T: 'static, + K: Send + 'static, +{ + let mut k = Some(k); + if let Some(k) = ::downcast_mut::>(&mut k) { + Ok(k.take().unwrap()) + } else { + Err(k.unwrap()) + } +} + +// copied from hyper under the following license: +// Copyright (c) 2014-2021 Sean McArthur + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. +pub(crate) async fn to_bytes(body: T) -> Result +where + T: Body, +{ + futures_util::pin_mut!(body); + + // If there's only 1 chunk, we can just return Buf::to_bytes() + let mut first = if let Some(buf) = body.data().await { + buf? + } else { + return Ok(Bytes::new()); + }; + + let second = if let Some(buf) = body.data().await { + buf? + } else { + return Ok(first.copy_to_bytes(first.remaining())); + }; + + // With more than 1 buf, we gotta flatten into a Vec first. + let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize; + let mut vec = Vec::with_capacity(cap); + vec.put(first); + vec.put(second); + + while let Some(buf) = body.data().await { + vec.put(buf?); + } + + Ok(vec.into()) +} + +#[test] +fn test_try_downcast() { + assert_eq!(try_downcast::(5_u32), Err(5_u32)); + assert_eq!(try_downcast::(5_i32), Ok(5_i32)); +} diff --git a/axum/src/error.rs b/axum-core/src/error.rs similarity index 87% rename from axum/src/error.rs rename to axum-core/src/error.rs index cac48907..93e5295f 100644 --- a/axum/src/error.rs +++ b/axum-core/src/error.rs @@ -8,7 +8,8 @@ pub struct Error { } impl Error { - pub(crate) fn new(error: impl Into) -> Self { + /// Create a new `Error` from a boxable error. + pub fn new(error: impl Into) -> Self { Self { inner: error.into(), } diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs new file mode 100644 index 00000000..c3951a67 --- /dev/null +++ b/axum-core/src/extract/mod.rs @@ -0,0 +1,284 @@ +//! Types and traits for extracting data from requests. +//! +//! See [`axum::extract`] for more details. +//! +//! [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html + +use self::rejection::*; +use crate::response::IntoResponse; +use crate::Error; +use async_trait::async_trait; +use http::{Extensions, HeaderMap, Method, Request, Uri, Version}; +use std::convert::Infallible; + +pub mod rejection; + +mod request_parts; +mod tuple; + +/// Types that can be created from requests. +/// +/// See [`axum::extract`] for more details. +/// +/// # What is the `B` type parameter? +/// +/// `FromRequest` is generic over the request body (the `B` in +/// [`http::Request`]). This is to allow `FromRequest` to be usable with any +/// type of request body. This is necessary because some middleware change the +/// request body, for example to add timeouts. +/// +/// If you're writing your own `FromRequest` that wont be used outside your +/// application, and not using any middleware that changes the request body, you +/// can most likely use `axum::body::Body`. +/// +/// If you're writing a library that's intended for others to use, it's recommended +/// to keep the generic type parameter: +/// +/// ```rust +/// use axum::{ +/// async_trait, +/// extract::{FromRequest, RequestParts}, +/// }; +/// +/// struct MyExtractor; +/// +/// #[async_trait] +/// impl FromRequest for MyExtractor +/// where +/// B: Send, // required by `async_trait` +/// { +/// type Rejection = http::StatusCode; +/// +/// async fn from_request(req: &mut RequestParts) -> Result { +/// // ... +/// # unimplemented!() +/// } +/// } +/// ``` +/// +/// This ensures your extractor is as flexible as possible. +/// +/// [`http::Request`]: http::Request +/// [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html +#[async_trait] +pub trait FromRequest: Sized { + /// If the extractor fails it'll use this "rejection" type. A rejection is + /// a kind of error that can be converted into a response. + type Rejection: IntoResponse; + + /// Perform the extraction. + async fn from_request(req: &mut RequestParts) -> Result; +} + +/// The type used with [`FromRequest`] to extract data from requests. +/// +/// Has several convenience methods for getting owned parts of the request. +#[derive(Debug)] +pub struct RequestParts { + method: Method, + uri: Uri, + version: Version, + headers: Option, + extensions: Option, + body: Option, +} + +impl RequestParts { + /// Create a new `RequestParts`. + /// + /// You generally shouldn't need to construct this type yourself, unless + /// using extractors outside of axum for example to implement a + /// [`tower::Service`]. + /// + /// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html + pub fn new(req: Request) -> Self { + let ( + http::request::Parts { + method, + uri, + version, + headers, + extensions, + .. + }, + body, + ) = req.into_parts(); + + RequestParts { + method, + uri, + version, + headers: Some(headers), + extensions: Some(extensions), + body: Some(body), + } + } + + /// Convert this `RequestParts` back into a [`Request`]. + /// + /// Fails if + /// + /// - The full [`HeaderMap`] has been extracted, that is [`take_headers`] + /// have been called. + /// - The full [`Extensions`] has been extracted, that is + /// [`take_extensions`] have been called. + /// - The request body has been extracted, that is [`take_body`] have been + /// called. + /// + /// [`take_headers`]: RequestParts::take_headers + /// [`take_extensions`]: RequestParts::take_extensions + /// [`take_body`]: RequestParts::take_body + pub fn try_into_request(self) -> Result, Error> { + let Self { + method, + uri, + version, + mut headers, + mut extensions, + mut body, + } = self; + + let mut req = if let Some(body) = body.take() { + Request::new(body) + } else { + return Err(Error::new(RequestAlreadyExtracted::BodyAlreadyExtracted( + BodyAlreadyExtracted, + ))); + }; + + *req.method_mut() = method; + *req.uri_mut() = uri; + *req.version_mut() = version; + + if let Some(headers) = headers.take() { + *req.headers_mut() = headers; + } else { + return Err(Error::new( + RequestAlreadyExtracted::HeadersAlreadyExtracted(HeadersAlreadyExtracted), + )); + } + + if let Some(extensions) = extensions.take() { + *req.extensions_mut() = extensions; + } else { + return Err(Error::new( + RequestAlreadyExtracted::ExtensionsAlreadyExtracted(ExtensionsAlreadyExtracted), + )); + } + + Ok(req) + } + + /// Gets a reference the request method. + pub fn method(&self) -> &Method { + &self.method + } + + /// Gets a mutable reference to the request method. + pub fn method_mut(&mut self) -> &mut Method { + &mut self.method + } + + /// Gets a reference the request URI. + pub fn uri(&self) -> &Uri { + &self.uri + } + + /// Gets a mutable reference to the request URI. + pub fn uri_mut(&mut self) -> &mut Uri { + &mut self.uri + } + + /// Get the request HTTP version. + pub fn version(&self) -> Version { + self.version + } + + /// Gets a mutable reference to the request HTTP version. + pub fn version_mut(&mut self) -> &mut Version { + &mut self.version + } + + /// Gets a reference to the request headers. + /// + /// Returns `None` if the headers has been taken by another extractor. + pub fn headers(&self) -> Option<&HeaderMap> { + self.headers.as_ref() + } + + /// Gets a mutable reference to the request headers. + /// + /// Returns `None` if the headers has been taken by another extractor. + pub fn headers_mut(&mut self) -> Option<&mut HeaderMap> { + self.headers.as_mut() + } + + /// Takes the headers out of the request, leaving a `None` in its place. + pub fn take_headers(&mut self) -> Option { + self.headers.take() + } + + /// Gets a reference to the request extensions. + /// + /// Returns `None` if the extensions has been taken by another extractor. + pub fn extensions(&self) -> Option<&Extensions> { + self.extensions.as_ref() + } + + /// Gets a mutable reference to the request extensions. + /// + /// Returns `None` if the extensions has been taken by another extractor. + pub fn extensions_mut(&mut self) -> Option<&mut Extensions> { + self.extensions.as_mut() + } + + /// Takes the extensions out of the request, leaving a `None` in its place. + pub fn take_extensions(&mut self) -> Option { + self.extensions.take() + } + + /// Gets a reference to the request body. + /// + /// Returns `None` if the body has been taken by another extractor. + pub fn body(&self) -> Option<&B> { + self.body.as_ref() + } + + /// Gets a mutable reference to the request body. + /// + /// Returns `None` if the body has been taken by another extractor. + pub fn body_mut(&mut self) -> Option<&mut B> { + self.body.as_mut() + } + + /// Takes the body out of the request, leaving a `None` in its place. + pub fn take_body(&mut self) -> Option { + self.body.take() + } +} + +#[async_trait] +impl FromRequest for Option +where + T: FromRequest, + B: Send, +{ + type Rejection = Infallible; + + async fn from_request(req: &mut RequestParts) -> Result, Self::Rejection> { + Ok(T::from_request(req).await.ok()) + } +} + +#[async_trait] +impl FromRequest for Result +where + T: FromRequest, + B: Send, +{ + type Rejection = Infallible; + + async fn from_request(req: &mut RequestParts) -> Result { + Ok(T::from_request(req).await) + } +} diff --git a/axum-core/src/extract/rejection.rs b/axum-core/src/extract/rejection.rs new file mode 100644 index 00000000..ad4f0eeb --- /dev/null +++ b/axum-core/src/extract/rejection.rs @@ -0,0 +1,85 @@ +//! Rejection response types. + +define_rejection! { + #[status = INTERNAL_SERVER_ERROR] + #[body = "Cannot have two request body extractors for a single handler"] + /// Rejection type used if you try and extract the request body more than + /// once. + pub struct BodyAlreadyExtracted; +} + +define_rejection! { + #[status = INTERNAL_SERVER_ERROR] + #[body = "Headers taken by other extractor"] + /// Rejection used if the headers has been taken by another extractor. + pub struct HeadersAlreadyExtracted; +} + +define_rejection! { + #[status = INTERNAL_SERVER_ERROR] + #[body = "Extensions taken by other extractor"] + /// Rejection used if the request extension has been taken by another + /// extractor. + pub struct ExtensionsAlreadyExtracted; +} + +define_rejection! { + #[status = BAD_REQUEST] + #[body = "Failed to buffer the request body"] + /// Rejection type for extractors that buffer the request body. Used if the + /// request body cannot be buffered due to an error. + pub struct FailedToBufferBody(Error); +} + +define_rejection! { + #[status = BAD_REQUEST] + #[body = "Request body didn't contain valid UTF-8"] + /// Rejection type used when buffering the request into a [`String`] if the + /// body doesn't contain valid UTF-8. + pub struct InvalidUtf8(Error); +} + +composite_rejection! { + /// Rejection used for [`Request<_>`]. + /// + /// Contains one variant for each way the [`Request<_>`] extractor can fail. + /// + /// [`Request<_>`]: http::Request + pub enum RequestAlreadyExtracted { + BodyAlreadyExtracted, + HeadersAlreadyExtracted, + ExtensionsAlreadyExtracted, + } +} + +composite_rejection! { + /// Rejection used for [`Bytes`](bytes::Bytes). + /// + /// Contains one variant for each way the [`Bytes`](bytes::Bytes) extractor + /// can fail. + pub enum BytesRejection { + BodyAlreadyExtracted, + FailedToBufferBody, + } +} + +composite_rejection! { + /// Rejection used for [`String`]. + /// + /// Contains one variant for each way the [`String`] extractor can fail. + pub enum StringRejection { + BodyAlreadyExtracted, + FailedToBufferBody, + InvalidUtf8, + } +} + +composite_rejection! { + /// Rejection used for [`http::request::Parts`]. + /// + /// Contains one variant for each way the [`http::request::Parts`] extractor can fail. + pub enum RequestPartsAlreadyExtracted { + HeadersAlreadyExtracted, + ExtensionsAlreadyExtracted, + } +} diff --git a/axum-core/src/extract/request_parts.rs b/axum-core/src/extract/request_parts.rs new file mode 100644 index 00000000..4752a8d7 --- /dev/null +++ b/axum-core/src/extract/request_parts.rs @@ -0,0 +1,182 @@ +use super::{rejection::*, FromRequest, RequestParts}; +use crate::BoxError; +use async_trait::async_trait; +use bytes::Bytes; +use http::{Extensions, HeaderMap, Method, Request, Uri, Version}; +use std::convert::Infallible; + +#[async_trait] +impl FromRequest for Request +where + B: Send, +{ + type Rejection = RequestAlreadyExtracted; + + async fn from_request(req: &mut RequestParts) -> Result { + let req = std::mem::replace( + req, + RequestParts { + method: req.method.clone(), + version: req.version, + uri: req.uri.clone(), + headers: None, + extensions: None, + body: None, + }, + ); + + let err = match req.try_into_request() { + Ok(req) => return Ok(req), + Err(err) => err, + }; + + match err.downcast::() { + Ok(err) => return Err(err), + Err(err) => unreachable!( + "Unexpected error type from `try_into_request`: `{:?}`. This is a bug in axum, please file an issue", + err, + ), + } + } +} + +#[async_trait] +impl FromRequest for Method +where + B: Send, +{ + type Rejection = Infallible; + + async fn from_request(req: &mut RequestParts) -> Result { + Ok(req.method().clone()) + } +} + +#[async_trait] +impl FromRequest for Uri +where + B: Send, +{ + type Rejection = Infallible; + + async fn from_request(req: &mut RequestParts) -> Result { + Ok(req.uri().clone()) + } +} + +#[async_trait] +impl FromRequest for Version +where + B: Send, +{ + type Rejection = Infallible; + + async fn from_request(req: &mut RequestParts) -> Result { + Ok(req.version()) + } +} + +#[async_trait] +impl FromRequest for HeaderMap +where + B: Send, +{ + type Rejection = HeadersAlreadyExtracted; + + async fn from_request(req: &mut RequestParts) -> Result { + req.take_headers().ok_or(HeadersAlreadyExtracted) + } +} + +#[async_trait] +impl FromRequest for Extensions +where + B: Send, +{ + type Rejection = ExtensionsAlreadyExtracted; + + async fn from_request(req: &mut RequestParts) -> Result { + req.take_extensions().ok_or(ExtensionsAlreadyExtracted) + } +} + +#[async_trait] +impl FromRequest for Bytes +where + B: http_body::Body + Send, + B::Data: Send, + B::Error: Into, +{ + type Rejection = BytesRejection; + + async fn from_request(req: &mut RequestParts) -> Result { + let body = take_body(req)?; + + let bytes = crate::body::to_bytes(body) + .await + .map_err(FailedToBufferBody::from_err)?; + + Ok(bytes) + } +} + +#[async_trait] +impl FromRequest for String +where + B: http_body::Body + Send, + B::Data: Send, + B::Error: Into, +{ + type Rejection = StringRejection; + + async fn from_request(req: &mut RequestParts) -> Result { + let body = take_body(req)?; + + let bytes = crate::body::to_bytes(body) + .await + .map_err(FailedToBufferBody::from_err)? + .to_vec(); + + let string = String::from_utf8(bytes).map_err(InvalidUtf8::from_err)?; + + Ok(string) + } +} + +#[async_trait] +impl FromRequest for http::request::Parts +where + B: Send, +{ + type Rejection = RequestPartsAlreadyExtracted; + + async fn from_request(req: &mut RequestParts) -> Result { + let method = unwrap_infallible(Method::from_request(req).await); + let uri = unwrap_infallible(Uri::from_request(req).await); + let version = unwrap_infallible(Version::from_request(req).await); + let headers = HeaderMap::from_request(req).await?; + let extensions = Extensions::from_request(req).await?; + + let mut temp_request = Request::new(()); + *temp_request.method_mut() = method; + *temp_request.uri_mut() = uri; + *temp_request.version_mut() = version; + *temp_request.headers_mut() = headers; + *temp_request.extensions_mut() = extensions; + + let (parts, _) = temp_request.into_parts(); + + Ok(parts) + } +} + +fn unwrap_infallible(result: Result) -> T { + match result { + Ok(value) => value, + Err(err) => match err {}, + } +} + +pub(crate) fn take_body(req: &mut RequestParts) -> Result { + req.take_body().ok_or(BodyAlreadyExtracted) +} diff --git a/axum/src/extract/tuple.rs b/axum-core/src/extract/tuple.rs similarity index 100% rename from axum/src/extract/tuple.rs rename to axum-core/src/extract/tuple.rs diff --git a/axum-core/src/lib.rs b/axum-core/src/lib.rs new file mode 100644 index 00000000..6d15aade --- /dev/null +++ b/axum-core/src/lib.rs @@ -0,0 +1,63 @@ +//! Core types and traits for [`axum`]. +//! +//! Libraries authors that want to provide [`FromRequest`] or [`IntoResponse`] implementations +//! should depend on the [`axum-core`] crate, instead of `axum` if possible. +//! +//! [`FromRequest`]: crate::extract::FromRequest +//! [`IntoResponse`]: crate::response::IntoResponse +//! [`axum`]: https://crates.io/crates/axum +//! [`axum-core`]: http://crates.io/crates/axum-core + +#![warn( + clippy::all, + clippy::dbg_macro, + clippy::todo, + clippy::empty_enum, + clippy::enum_glob_use, + clippy::mem_forget, + clippy::unused_self, + clippy::filter_map_next, + clippy::needless_continue, + clippy::needless_borrow, + clippy::match_wildcard_for_single_variants, + clippy::if_let_mutex, + clippy::mismatched_target_os, + clippy::await_holding_lock, + clippy::match_on_vec_items, + clippy::imprecise_flops, + clippy::suboptimal_flops, + clippy::lossy_float_literal, + clippy::rest_pat_in_fully_bound_structs, + clippy::fn_params_excessive_bools, + clippy::exit, + clippy::inefficient_to_string, + clippy::linkedlist, + clippy::macro_use_imports, + clippy::option_option, + clippy::verbose_file_reads, + clippy::unnested_or_patterns, + clippy::str_to_string, + rust_2018_idioms, + future_incompatible, + nonstandard_style, + missing_debug_implementations, + missing_docs +)] +#![deny(unreachable_pub, private_in_public)] +#![allow(elided_lifetimes_in_paths, clippy::type_complexity)] +#![forbid(unsafe_code)] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(test, allow(clippy::float_cmp))] + +#[macro_use] +pub(crate) mod macros; + +mod error; +pub use self::error::Error; + +pub mod body; +pub mod extract; +pub mod response; + +/// Alias for a type-erased error type. +pub type BoxError = Box; diff --git a/axum-core/src/macros.rs b/axum-core/src/macros.rs new file mode 100644 index 00000000..5e5a76cd --- /dev/null +++ b/axum-core/src/macros.rs @@ -0,0 +1,159 @@ +macro_rules! define_rejection { + ( + #[status = $status:ident] + #[body = $body:expr] + $(#[$m:meta])* + pub struct $name:ident; + ) => { + $(#[$m])* + #[derive(Debug)] + #[non_exhaustive] + pub struct $name; + + #[allow(deprecated)] + impl $crate::response::IntoResponse for $name { + fn into_response(self) -> http::Response<$crate::body::BoxBody> { + let mut res = http::Response::new($crate::body::boxed(http_body::Full::from($body))); + *res.status_mut() = http::StatusCode::$status; + res + } + } + + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", $body) + } + } + + impl std::error::Error for $name {} + + impl Default for $name { + fn default() -> Self { + Self + } + } + }; + + ( + #[status = $status:ident] + #[body = $body:expr] + $(#[$m:meta])* + pub struct $name:ident (Error); + ) => { + $(#[$m])* + #[derive(Debug)] + pub struct $name(pub(crate) crate::Error); + + impl $name { + pub(crate) fn from_err(err: E) -> Self + where + E: Into, + { + Self(crate::Error::new(err)) + } + } + + impl crate::response::IntoResponse for $name { + fn into_response(self) -> http::Response<$crate::body::BoxBody> { + let body = http_body::Full::from(format!(concat!($body, ": {}"), self.0)); + let body = $crate::body::boxed(body); + let mut res = + http::Response::new(body); + *res.status_mut() = http::StatusCode::$status; + res + } + } + + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", $body) + } + } + + impl std::error::Error for $name { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&self.0) + } + } + }; +} + +macro_rules! composite_rejection { + ( + $(#[$m:meta])* + pub enum $name:ident { + $($variant:ident),+ + $(,)? + } + ) => { + $(#[$m])* + #[derive(Debug)] + #[non_exhaustive] + pub enum $name { + $( + #[allow(missing_docs, deprecated)] + $variant($variant) + ),+ + } + + impl $crate::response::IntoResponse for $name { + fn into_response(self) -> http::Response<$crate::body::BoxBody> { + match self { + $( + Self::$variant(inner) => inner.into_response(), + )+ + } + } + } + + $( + #[allow(deprecated)] + impl From<$variant> for $name { + fn from(inner: $variant) -> Self { + Self::$variant(inner) + } + } + )+ + + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + $( + Self::$variant(inner) => write!(f, "{}", inner), + )+ + } + } + } + + impl std::error::Error for $name { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + $( + Self::$variant(inner) => Some(inner), + )+ + } + } + } + }; +} + +macro_rules! all_the_tuples { + ($name:ident) => { + $name!(T1); + $name!(T1, T2); + $name!(T1, T2, T3); + $name!(T1, T2, T3, T4); + $name!(T1, T2, T3, T4, T5); + $name!(T1, T2, T3, T4, T5, T6); + $name!(T1, T2, T3, T4, T5, T6, T7); + $name!(T1, T2, T3, T4, T5, T6, T7, T8); + $name!(T1, T2, T3, T4, T5, T6, T7, T8, T9); + $name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10); + $name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11); + $name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12); + $name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13); + $name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14); + $name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15); + $name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); + }; +} diff --git a/axum/src/response/headers.rs b/axum-core/src/response/headers.rs similarity index 97% rename from axum/src/response/headers.rs rename to axum-core/src/response/headers.rs index 7415998f..a82f6a4a 100644 --- a/axum/src/response/headers.rs +++ b/axum-core/src/response/headers.rs @@ -7,7 +7,6 @@ use http::{ }; use http_body::{Empty, Full}; use std::{convert::TryInto, fmt}; -use tower::util::Either; /// A response with headers. /// @@ -146,6 +145,11 @@ where } } +enum Either { + A(A), + B(B), +} + #[cfg(test)] mod tests { use super::*; @@ -172,7 +176,7 @@ mod tests { let res = (Headers(vec![("user-agent", "axum")]), "foo").into_response(); assert_eq!(res.headers()["user-agent"], "axum"); - let body = hyper::body::to_bytes(res.into_body()) + let body = crate::body::to_bytes(res.into_body()) .now_or_never() .unwrap() .unwrap(); @@ -190,7 +194,7 @@ mod tests { assert_eq!(res.headers()["user-agent"], "axum"); assert_eq!(res.status(), StatusCode::NOT_FOUND); - let body = hyper::body::to_bytes(res.into_body()) + let body = crate::body::to_bytes(res.into_body()) .now_or_never() .unwrap() .unwrap(); diff --git a/axum-core/src/response/mod.rs b/axum-core/src/response/mod.rs new file mode 100644 index 00000000..f5d58c32 --- /dev/null +++ b/axum-core/src/response/mod.rs @@ -0,0 +1,359 @@ +//! Types and traits for generating responses. +//! +//! See [`axum::response`] for more details. +//! +//! [`axum::response`]: https://docs.rs/axum/latest/axum/response/index.html + +use crate::{ + body::{boxed, BoxBody}, + BoxError, +}; +use bytes::Bytes; +use http::{ + header::{self, HeaderMap, HeaderValue}, + Response, StatusCode, +}; +use http_body::{ + combinators::{MapData, MapErr}, + Empty, Full, +}; +use std::{borrow::Cow, convert::Infallible}; + +mod headers; + +#[doc(inline)] +pub use self::headers::Headers; + +/// Trait for generating responses. +/// +/// Types that implement `IntoResponse` can be returned from handlers. +/// +/// # Implementing `IntoResponse` +/// +/// You generally shouldn't have to implement `IntoResponse` manually, as axum +/// provides implementations for many common types. +/// +/// However it might be necessary if you have a custom error type that you want +/// to return from handlers: +/// +/// ```rust +/// use axum::{ +/// Router, +/// body::{self, BoxBody, Bytes}, +/// routing::get, +/// http::{Response, StatusCode}, +/// response::IntoResponse, +/// }; +/// +/// enum MyError { +/// SomethingWentWrong, +/// SomethingElseWentWrong, +/// } +/// +/// impl IntoResponse for MyError { +/// fn into_response(self) -> Response { +/// let body = match self { +/// MyError::SomethingWentWrong => { +/// body::boxed(body::Full::from("something went wrong")) +/// }, +/// MyError::SomethingElseWentWrong => { +/// body::boxed(body::Full::from("something else went wrong")) +/// }, +/// }; +/// +/// Response::builder() +/// .status(StatusCode::INTERNAL_SERVER_ERROR) +/// .body(body) +/// .unwrap() +/// } +/// } +/// +/// // `Result` can now be returned from handlers +/// let app = Router::new().route("/", get(handler)); +/// +/// async fn handler() -> Result<(), MyError> { +/// Err(MyError::SomethingWentWrong) +/// } +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// Or if you have a custom body type you'll also need to implement +/// `IntoResponse` for it: +/// +/// ```rust +/// use axum::{ +/// body::{self, BoxBody}, +/// routing::get, +/// response::IntoResponse, +/// Router, +/// }; +/// use http_body::Body; +/// use http::{Response, HeaderMap}; +/// use bytes::Bytes; +/// use std::{ +/// convert::Infallible, +/// task::{Poll, Context}, +/// pin::Pin, +/// }; +/// +/// struct MyBody; +/// +/// // First implement `Body` for `MyBody`. This could for example use +/// // some custom streaming protocol. +/// impl Body for MyBody { +/// type Data = Bytes; +/// type Error = Infallible; +/// +/// fn poll_data( +/// self: Pin<&mut Self>, +/// cx: &mut Context<'_> +/// ) -> Poll>> { +/// # unimplemented!() +/// // ... +/// } +/// +/// fn poll_trailers( +/// self: Pin<&mut Self>, +/// cx: &mut Context<'_> +/// ) -> Poll, Self::Error>> { +/// # unimplemented!() +/// // ... +/// } +/// } +/// +/// // Now we can implement `IntoResponse` directly for `MyBody` +/// impl IntoResponse for MyBody { +/// fn into_response(self) -> Response { +/// Response::new(body::boxed(self)) +/// } +/// } +/// +/// // We don't need to implement `IntoResponse for Response` as that is +/// // covered by a blanket implementation in axum. +/// +/// // `MyBody` can now be returned from handlers. +/// let app = Router::new().route("/", get(|| async { MyBody })); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +pub trait IntoResponse { + /// Create a response. + fn into_response(self) -> Response; +} + +impl IntoResponse for () { + fn into_response(self) -> Response { + Response::new(boxed(Empty::new())) + } +} + +impl IntoResponse for Infallible { + fn into_response(self) -> Response { + match self {} + } +} + +impl IntoResponse for Result +where + T: IntoResponse, + E: IntoResponse, +{ + fn into_response(self) -> Response { + match self { + Ok(value) => value.into_response(), + Err(err) => err.into_response(), + } + } +} + +impl IntoResponse for Response +where + B: http_body::Body + Send + 'static, + B::Error: Into, +{ + fn into_response(self) -> Response { + self.map(boxed) + } +} + +macro_rules! impl_into_response_for_body { + ($body:ty) => { + impl IntoResponse for $body { + fn into_response(self) -> Response { + Response::new(boxed(self)) + } + } + }; +} + +impl_into_response_for_body!(Full); +impl_into_response_for_body!(Empty); + +impl IntoResponse for http::response::Parts { + fn into_response(self) -> Response { + Response::from_parts(self, boxed(Empty::new())) + } +} + +impl IntoResponse for http_body::combinators::BoxBody +where + E: Into + 'static, +{ + fn into_response(self) -> Response { + Response::new(boxed(self)) + } +} + +impl IntoResponse for http_body::combinators::UnsyncBoxBody +where + E: Into + 'static, +{ + fn into_response(self) -> Response { + Response::new(boxed(self)) + } +} + +impl IntoResponse for MapData +where + B: http_body::Body + Send + 'static, + F: FnMut(B::Data) -> Bytes + Send + 'static, + B::Error: Into, +{ + fn into_response(self) -> Response { + Response::new(boxed(self)) + } +} + +impl IntoResponse for MapErr +where + B: http_body::Body + Send + 'static, + F: FnMut(B::Error) -> E + Send + 'static, + E: Into, +{ + fn into_response(self) -> Response { + Response::new(boxed(self)) + } +} + +impl IntoResponse for &'static str { + #[inline] + fn into_response(self) -> Response { + Cow::Borrowed(self).into_response() + } +} + +impl IntoResponse for String { + #[inline] + fn into_response(self) -> Response { + Cow::<'static, str>::Owned(self).into_response() + } +} + +impl IntoResponse for Cow<'static, str> { + fn into_response(self) -> Response { + let mut res = Response::new(boxed(Full::from(self))); + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()), + ); + res + } +} + +impl IntoResponse for Bytes { + fn into_response(self) -> Response { + let mut res = Response::new(boxed(Full::from(self))); + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM.as_ref()), + ); + res + } +} + +impl IntoResponse for &'static [u8] { + fn into_response(self) -> Response { + let mut res = Response::new(boxed(Full::from(self))); + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM.as_ref()), + ); + res + } +} + +impl IntoResponse for Vec { + fn into_response(self) -> Response { + let mut res = Response::new(boxed(Full::from(self))); + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM.as_ref()), + ); + res + } +} + +impl IntoResponse for Cow<'static, [u8]> { + fn into_response(self) -> Response { + let mut res = Response::new(boxed(Full::from(self))); + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM.as_ref()), + ); + res + } +} + +impl IntoResponse for StatusCode { + fn into_response(self) -> Response { + Response::builder() + .status(self) + .body(boxed(Empty::new())) + .unwrap() + } +} + +impl IntoResponse for (StatusCode, T) +where + T: IntoResponse, +{ + fn into_response(self) -> Response { + let mut res = self.1.into_response(); + *res.status_mut() = self.0; + res + } +} + +impl IntoResponse for (HeaderMap, T) +where + T: IntoResponse, +{ + fn into_response(self) -> Response { + let mut res = self.1.into_response(); + res.headers_mut().extend(self.0); + res + } +} + +impl IntoResponse for (StatusCode, HeaderMap, T) +where + T: IntoResponse, +{ + fn into_response(self) -> Response { + let mut res = self.2.into_response(); + *res.status_mut() = self.0; + res.headers_mut().extend(self.1); + res + } +} + +impl IntoResponse for HeaderMap { + fn into_response(self) -> Response { + let mut res = Response::new(boxed(Empty::new())); + *res.headers_mut() = self; + res + } +} diff --git a/axum-debug/src/lib.rs b/axum-debug/src/lib.rs index e35c0ed4..36da6bc8 100644 --- a/axum-debug/src/lib.rs +++ b/axum-debug/src/lib.rs @@ -237,7 +237,7 @@ mod debug_handler { #[allow(warnings)] fn #name() where - #ty: ::axum::extract::FromRequest + Send, + #ty: ::axum::extract::FromRequest<::axum::body::Body> + Send, {} } }) diff --git a/axum-debug/tests/fail/argument_not_extractor.stderr b/axum-debug/tests/fail/argument_not_extractor.stderr index 5ee9a59a..f28dffbc 100644 --- a/axum-debug/tests/fail/argument_not_extractor.stderr +++ b/axum-debug/tests/fail/argument_not_extractor.stderr @@ -1,7 +1,7 @@ -error[E0277]: the trait bound `bool: FromRequest` is not satisfied +error[E0277]: the trait bound `bool: FromRequest` is not satisfied --> tests/fail/argument_not_extractor.rs:4:23 | 4 | async fn handler(foo: bool) {} - | ^^^^ the trait `FromRequest` is not implemented for `bool` + | ^^^^ the trait `FromRequest` is not implemented for `bool` | = help: see issue #48214 diff --git a/axum-debug/tests/fail/extract_self_mut.rs b/axum-debug/tests/fail/extract_self_mut.rs index cc291267..0954136d 100644 --- a/axum-debug/tests/fail/extract_self_mut.rs +++ b/axum-debug/tests/fail/extract_self_mut.rs @@ -7,10 +7,13 @@ use axum_debug::debug_handler; struct A; #[async_trait] -impl FromRequest for A { +impl FromRequest for A +where + B: Send + 'static, +{ type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-debug/tests/fail/extract_self_mut.stderr b/axum-debug/tests/fail/extract_self_mut.stderr index a9b604bc..e0a401e4 100644 --- a/axum-debug/tests/fail/extract_self_mut.stderr +++ b/axum-debug/tests/fail/extract_self_mut.stderr @@ -1,5 +1,5 @@ error: Handlers must only take owned values - --> tests/fail/extract_self_mut.rs:20:22 + --> tests/fail/extract_self_mut.rs:23:22 | -20 | async fn handler(&mut self) {} +23 | async fn handler(&mut self) {} | ^^^^^^^^^ diff --git a/axum-debug/tests/fail/extract_self_ref.rs b/axum-debug/tests/fail/extract_self_ref.rs index a9d025f7..f8b3e0ed 100644 --- a/axum-debug/tests/fail/extract_self_ref.rs +++ b/axum-debug/tests/fail/extract_self_ref.rs @@ -7,10 +7,13 @@ use axum_debug::debug_handler; struct A; #[async_trait] -impl FromRequest for A { +impl FromRequest for A +where + B: Send + 'static, +{ type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-debug/tests/fail/extract_self_ref.stderr b/axum-debug/tests/fail/extract_self_ref.stderr index 7f628bc6..8c5e2548 100644 --- a/axum-debug/tests/fail/extract_self_ref.stderr +++ b/axum-debug/tests/fail/extract_self_ref.stderr @@ -1,5 +1,5 @@ error: Handlers must only take owned values - --> tests/fail/extract_self_ref.rs:20:22 + --> tests/fail/extract_self_ref.rs:23:22 | -20 | async fn handler(&self) {} +23 | async fn handler(&self) {} | ^^^^^ diff --git a/axum-debug/tests/pass/returns_self.rs b/axum-debug/tests/pass/returns_self.rs index 449e8a15..eb520fbd 100644 --- a/axum-debug/tests/pass/returns_self.rs +++ b/axum-debug/tests/pass/returns_self.rs @@ -4,7 +4,6 @@ use axum::{ response::IntoResponse, }; use axum_debug::debug_handler; -use std::convert::Infallible; struct A; diff --git a/axum-debug/tests/pass/self_receiver.rs b/axum-debug/tests/pass/self_receiver.rs index 3f5b68c0..c9d23d25 100644 --- a/axum-debug/tests/pass/self_receiver.rs +++ b/axum-debug/tests/pass/self_receiver.rs @@ -7,10 +7,13 @@ use axum_debug::debug_handler; struct A; #[async_trait] -impl FromRequest for A { +impl FromRequest for A +where + B: Send + 'static, +{ type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index d274db51..75f9bf81 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -13,6 +13,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 `MethodRouter::route_layer`. - Merge method routers with `MethodRouter::merge` - Customize response for unsupported methods with `MethodRouter::fallback` +- **breaking:** The default for the type parameter in `FromRequest` and + `RequestParts` has been removed. Use `FromRequest` and + `RequestParts` to get the previous behavior ([#564]) +- **added:** `FromRequest` and `IntoResponse` are now defined in a new called + `axum-core`. This crate is intended for library authors to depend on, rather + than `axum` itself, if possible. `axum-core` has a smaller API and will thus + receive fewer breaking changes. `FromRequest` and `IntoResponse` are + re-exported from `axum` in the same location so nothing is changed for `axum` + users ([#564]) - **breaking:** The previously deprecated `axum::body::box_body` function has been removed. Use `axum::body::boxed` instead. - **fixed:** Adding the same route with different methods now works ie @@ -43,6 +52,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#529]: https://github.com/tokio-rs/axum/pull/529 [#534]: https://github.com/tokio-rs/axum/pull/534 [#554]: https://github.com/tokio-rs/axum/pull/554 +[#564]: https://github.com/tokio-rs/axum/pull/564 [#571]: https://github.com/tokio-rs/axum/pull/571 # 0.3.3 (13. November, 2021) diff --git a/axum/Cargo.toml b/axum/Cargo.toml index 5548406f..12c87f11 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -20,6 +20,7 @@ tower-log = ["tower/log"] ws = ["tokio-tungstenite", "sha-1", "base64"] [dependencies] +axum-core = { path = "../axum-core", version = "0.1" } async-trait = "0.1.43" bitflags = "1.0" bytes = "1.0" diff --git a/axum/src/body/mod.rs b/axum/src/body/mod.rs index 6b449cf4..d5c5f661 100644 --- a/axum/src/body/mod.rs +++ b/axum/src/body/mod.rs @@ -1,7 +1,5 @@ //! HTTP body utilities. -use crate::{util::try_downcast, BoxError, Error}; - mod stream_body; pub use self::stream_body::StreamBody; @@ -15,20 +13,8 @@ pub use hyper::body::Body; #[doc(no_inline)] pub use bytes::Bytes; -/// A boxed [`Body`] trait object. -/// -/// This is used in axum as the response body type for applications. It's -/// necessary to unify multiple response bodies types into one. -pub type BoxBody = http_body::combinators::UnsyncBoxBody; - -/// Convert a [`http_body::Body`] into a [`BoxBody`]. -pub fn boxed(body: B) -> BoxBody -where - B: http_body::Body + Send + 'static, - B::Error: Into, -{ - try_downcast(body).unwrap_or_else(|body| body.map_err(Error::new).boxed_unsync()) -} +#[doc(inline)] +pub use axum_core::body::{boxed, BoxBody}; pub(crate) fn empty() -> BoxBody { boxed(http_body::Empty::new()) diff --git a/axum/src/docs/extract.md b/axum/src/docs/extract.md index 97cae977..4d40f1dd 100644 --- a/axum/src/docs/extract.md +++ b/axum/src/docs/extract.md @@ -240,8 +240,8 @@ async fn create_user(payload: Result, JsonRejection>) { Err(JsonRejection::InvalidJsonBody(_)) => { // Couldn't deserialize the body into the target type } - Err(JsonRejection::BodyAlreadyExtracted(_)) => { - // Another extractor had already consumed the body + Err(JsonRejection::BytesRejection(_)) => { + // Failed to extract the request body } Err(_) => { // `JsonRejection` is marked `#[non_exhaustive]` so match must @@ -316,9 +316,9 @@ async fn handler(result: Result, JsonRejection>) -> impl IntoRespons StatusCode::BAD_REQUEST, "Missing `Content-Type: application/json` header".to_string(), )), - JsonRejection::BodyAlreadyExtracted(_) => Err(( + JsonRejection::BytesRejection(_) => Err(( StatusCode::INTERNAL_SERVER_ERROR, - "Body already extracted".to_string(), + "Failed to buffer request body".to_string(), )), JsonRejection::HeadersAlreadyExtracted(_) => Err(( StatusCode::INTERNAL_SERVER_ERROR, @@ -474,3 +474,5 @@ let app = Router::new() [`body::Body`]: crate::body::Body [customize-extractor-error]: https://github.com/tokio-rs/axum/blob/main/examples/customize-extractor-error/src/main.rs +[`HeaderMap`]: https://docs.rs/http/latest/http/header/struct.HeaderMap.html +[`Request`]: https://docs.rs/http/latest/http/struct.Request.html diff --git a/axum/src/extract/content_length_limit.rs b/axum/src/extract/content_length_limit.rs index 81d2cf65..a513094a 100644 --- a/axum/src/extract/content_length_limit.rs +++ b/axum/src/extract/content_length_limit.rs @@ -41,9 +41,11 @@ where async fn from_request(req: &mut RequestParts) -> Result { let content_length = req .headers() - .ok_or(ContentLengthLimitRejection::HeadersAlreadyExtracted( - HeadersAlreadyExtracted, - ))? + .ok_or_else(|| { + ContentLengthLimitRejection::HeadersAlreadyExtracted( + HeadersAlreadyExtracted::default(), + ) + })? .get(http::header::CONTENT_LENGTH); let content_length = diff --git a/axum/src/extract/extension.rs b/axum/src/extract/extension.rs index 762af024..d65fb3a3 100644 --- a/axum/src/extract/extension.rs +++ b/axum/src/extract/extension.rs @@ -53,7 +53,7 @@ where async fn from_request(req: &mut RequestParts) -> Result { let value = req .extensions() - .ok_or(ExtensionsAlreadyExtracted)? + .ok_or_else(ExtensionsAlreadyExtracted::default)? .get::() .ok_or_else(|| { MissingExtension::from_err(format!( diff --git a/axum/src/extract/form.rs b/axum/src/extract/form.rs index a2ffeaa7..ed3b05cc 100644 --- a/axum/src/extract/form.rs +++ b/axum/src/extract/form.rs @@ -1,7 +1,6 @@ -use super::{has_content_type, rejection::*, take_body, FromRequest, RequestParts}; +use super::{has_content_type, rejection::*, FromRequest, RequestParts}; use crate::BoxError; use async_trait::async_trait; -use bytes::Buf; use http::Method; use serde::de::DeserializeOwned; use std::ops::Deref; @@ -64,11 +63,8 @@ where return Err(InvalidFormContentType.into()); } - let body = take_body(req)?; - let chunks = hyper::body::aggregate(body) - .await - .map_err(FailedToBufferBody::from_err)?; - let value = serde_urlencoded::from_reader(chunks.reader()) + let bytes = bytes::Bytes::from_request(req).await?; + let value = serde_urlencoded::from_bytes(&bytes) .map_err(FailedToDeserializeQueryString::new::)?; Ok(Form(value)) diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index 7218393b..9ef519d1 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -70,11 +70,9 @@ where type Rejection = MatchedPathRejection; async fn from_request(req: &mut RequestParts) -> Result { - let extensions = - req.extensions() - .ok_or(MatchedPathRejection::ExtensionsAlreadyExtracted( - ExtensionsAlreadyExtracted, - ))?; + let extensions = req.extensions().ok_or_else(|| { + MatchedPathRejection::ExtensionsAlreadyExtracted(ExtensionsAlreadyExtracted::default()) + })?; let matched_path = extensions .get::() @@ -114,7 +112,7 @@ mod tests { .get::() .unwrap() .as_str() - .to_string(); + .to_owned(); req.extensions_mut().insert(MatchedPathFromMiddleware(path)); self.0.call(req) } @@ -127,7 +125,7 @@ mod tests { async fn access_matched_path() { let api = Router::new().route( "/users/:id", - get(|path: MatchedPath| async move { path.as_str().to_string() }), + get(|path: MatchedPath| async move { path.as_str().to_owned() }), ); async fn handler( @@ -146,7 +144,7 @@ mod tests { let app = Router::new() .route( "/:key", - get(|path: MatchedPath| async move { path.as_str().to_string() }), + get(|path: MatchedPath| async move { path.as_str().to_owned() }), ) .nest("/api", api) .nest( diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index 9b89570f..ab7ff2c3 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -1,10 +1,8 @@ #![doc = include_str!("../docs/extract.md")] -use crate::{response::IntoResponse, Error}; -use async_trait::async_trait; -use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version}; +use crate::response::IntoResponse; +use http::header; use rejection::*; -use std::convert::Infallible; pub mod connect_info; pub mod extractor_middleware; @@ -22,7 +20,9 @@ mod path; mod query; mod raw_query; mod request_parts; -mod tuple; + +#[doc(inline)] +pub use axum_core::extract::{FromRequest, RequestParts}; #[doc(inline)] #[allow(deprecated)] @@ -66,277 +66,13 @@ mod typed_header; #[doc(inline)] pub use self::typed_header::TypedHeader; -/// Types that can be created from requests. -/// -/// See the [module docs](crate::extract) for more details. -/// -/// # What is the `B` type parameter? -/// -/// `FromRequest` is generic over the request body (the `B` in -/// [`http::Request`]). This is to allow `FromRequest` to be usable with any -/// type of request body. This is necessary because some middleware change the -/// request body, for example to add timeouts. -/// -/// If you're writing your own `FromRequest` that wont be used outside your -/// application, and not using any middleware that changes the request body, you -/// can most likely use `axum::body::Body`. Note that this is also the default. -/// -/// If you're writing a library that's intended for others to use, it's recommended -/// to keep the generic type parameter: -/// -/// ```rust -/// use axum::{ -/// async_trait, -/// extract::{FromRequest, RequestParts}, -/// }; -/// -/// struct MyExtractor; -/// -/// #[async_trait] -/// impl FromRequest for MyExtractor -/// where -/// B: Send, // required by `async_trait` -/// { -/// type Rejection = http::StatusCode; -/// -/// async fn from_request(req: &mut RequestParts) -> Result { -/// // ... -/// # unimplemented!() -/// } -/// } -/// ``` -/// -/// This ensures your extractor is as flexible as possible. -/// -/// [`http::Request`]: http::Request -#[async_trait] -pub trait FromRequest: Sized { - /// If the extractor fails it'll use this "rejection" type. A rejection is - /// a kind of error that can be converted into a response. - type Rejection: IntoResponse; - - /// Perform the extraction. - async fn from_request(req: &mut RequestParts) -> Result; -} - -/// The type used with [`FromRequest`] to extract data from requests. -/// -/// Has several convenience methods for getting owned parts of the request. -#[derive(Debug)] -pub struct RequestParts { - method: Method, - uri: Uri, - version: Version, - headers: Option, - extensions: Option, - body: Option, -} - -impl RequestParts { - /// Create a new `RequestParts`. - /// - /// You generally shouldn't need to construct this type yourself, unless - /// using extractors outside of axum for example to implement a - /// [`tower::Service`]. - pub fn new(req: Request) -> Self { - let ( - http::request::Parts { - method, - uri, - version, - headers, - extensions, - .. - }, - body, - ) = req.into_parts(); - - RequestParts { - method, - uri, - version, - headers: Some(headers), - extensions: Some(extensions), - body: Some(body), - } - } - - /// Convert this `RequestParts` back into a [`Request`]. - /// - /// Fails if - /// - /// - The full [`HeaderMap`] has been extracted, that is [`take_headers`] - /// have been called. - /// - The full [`Extensions`] has been extracted, that is - /// [`take_extensions`] have been called. - /// - The request body has been extracted, that is [`take_body`] have been - /// called. - /// - /// [`take_headers`]: RequestParts::take_headers - /// [`take_extensions`]: RequestParts::take_extensions - /// [`take_body`]: RequestParts::take_body - pub fn try_into_request(self) -> Result, Error> { - let Self { - method, - uri, - version, - mut headers, - mut extensions, - mut body, - } = self; - - let mut req = if let Some(body) = body.take() { - Request::new(body) - } else { - return Err(Error::new(RequestAlreadyExtracted::BodyAlreadyExtracted( - BodyAlreadyExtracted, - ))); - }; - - *req.method_mut() = method; - *req.uri_mut() = uri; - *req.version_mut() = version; - - if let Some(headers) = headers.take() { - *req.headers_mut() = headers; - } else { - return Err(Error::new( - RequestAlreadyExtracted::HeadersAlreadyExtracted(HeadersAlreadyExtracted), - )); - } - - if let Some(extensions) = extensions.take() { - *req.extensions_mut() = extensions; - } else { - return Err(Error::new( - RequestAlreadyExtracted::ExtensionsAlreadyExtracted(ExtensionsAlreadyExtracted), - )); - } - - Ok(req) - } - - /// Gets a reference the request method. - pub fn method(&self) -> &Method { - &self.method - } - - /// Gets a mutable reference to the request method. - pub fn method_mut(&mut self) -> &mut Method { - &mut self.method - } - - /// Gets a reference the request URI. - pub fn uri(&self) -> &Uri { - &self.uri - } - - /// Gets a mutable reference to the request URI. - pub fn uri_mut(&mut self) -> &mut Uri { - &mut self.uri - } - - /// Get the request HTTP version. - pub fn version(&self) -> Version { - self.version - } - - /// Gets a mutable reference to the request HTTP version. - pub fn version_mut(&mut self) -> &mut Version { - &mut self.version - } - - /// Gets a reference to the request headers. - /// - /// Returns `None` if the headers has been taken by another extractor. - pub fn headers(&self) -> Option<&HeaderMap> { - self.headers.as_ref() - } - - /// Gets a mutable reference to the request headers. - /// - /// Returns `None` if the headers has been taken by another extractor. - pub fn headers_mut(&mut self) -> Option<&mut HeaderMap> { - self.headers.as_mut() - } - - /// Takes the headers out of the request, leaving a `None` in its place. - pub fn take_headers(&mut self) -> Option { - self.headers.take() - } - - /// Gets a reference to the request extensions. - /// - /// Returns `None` if the extensions has been taken by another extractor. - pub fn extensions(&self) -> Option<&Extensions> { - self.extensions.as_ref() - } - - /// Gets a mutable reference to the request extensions. - /// - /// Returns `None` if the extensions has been taken by another extractor. - pub fn extensions_mut(&mut self) -> Option<&mut Extensions> { - self.extensions.as_mut() - } - - /// Takes the extensions out of the request, leaving a `None` in its place. - pub fn take_extensions(&mut self) -> Option { - self.extensions.take() - } - - /// Gets a reference to the request body. - /// - /// Returns `None` if the body has been taken by another extractor. - pub fn body(&self) -> Option<&B> { - self.body.as_ref() - } - - /// Gets a mutable reference to the request body. - /// - /// Returns `None` if the body has been taken by another extractor. - pub fn body_mut(&mut self) -> Option<&mut B> { - self.body.as_mut() - } - - /// Takes the body out of the request, leaving a `None` in its place. - pub fn take_body(&mut self) -> Option { - self.body.take() - } -} - -#[async_trait] -impl FromRequest for Option -where - T: FromRequest, - B: Send, -{ - type Rejection = Infallible; - - async fn from_request(req: &mut RequestParts) -> Result, Self::Rejection> { - Ok(T::from_request(req).await.ok()) - } -} - -#[async_trait] -impl FromRequest for Result -where - T: FromRequest, - B: Send, -{ - type Rejection = Infallible; - - async fn from_request(req: &mut RequestParts) -> Result { - Ok(T::from_request(req).await) - } -} - pub(crate) fn has_content_type( req: &RequestParts, expected_content_type: &mime::Mime, ) -> Result { let content_type = if let Some(content_type) = req .headers() - .ok_or(HeadersAlreadyExtracted)? + .ok_or_else(HeadersAlreadyExtracted::default)? .get(header::CONTENT_TYPE) { content_type @@ -354,7 +90,7 @@ pub(crate) fn has_content_type( } pub(crate) fn take_body(req: &mut RequestParts) -> Result { - req.take_body().ok_or(BodyAlreadyExtracted) + req.take_body().ok_or_else(BodyAlreadyExtracted::default) } #[cfg(test)] diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index e1d7104c..3cd6c692 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -59,7 +59,7 @@ where async fn from_request(req: &mut RequestParts) -> Result { let stream = BodyStream::from_request(req).await?; - let headers = req.headers().ok_or(HeadersAlreadyExtracted)?; + let headers = req.headers().ok_or_else(HeadersAlreadyExtracted::default)?; let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?; let multipart = multer::Multipart::new(stream, boundary); Ok(Self { inner: multipart }) diff --git a/axum/src/extract/path/de.rs b/axum/src/extract/path/de.rs index 3c944a38..e3c437ee 100644 --- a/axum/src/extract/path/de.rs +++ b/axum/src/extract/path/de.rs @@ -616,7 +616,7 @@ mod tests { let url_params = create_url_params(vec![("a", "1"), ("b", "2")]); assert_eq!( i32::deserialize(PathDeserializer::new(&url_params)).unwrap_err(), - PathDeserializerError::custom("wrong number of parameters: 2 expected 1".to_string()) + PathDeserializerError::custom("wrong number of parameters: 2 expected 1".to_owned()) ); } @@ -625,14 +625,14 @@ mod tests { let url_params = create_url_params(vec![("a", "1"), ("b", "true"), ("c", "abc")]); assert_eq!( <(i32, bool, String)>::deserialize(PathDeserializer::new(&url_params)).unwrap(), - (1, true, "abc".to_string()) + (1, true, "abc".to_owned()) ); #[derive(Debug, Deserialize, Eq, PartialEq)] struct TupleStruct(i32, bool, String); assert_eq!( TupleStruct::deserialize(PathDeserializer::new(&url_params)).unwrap(), - TupleStruct(1, true, "abc".to_string()) + TupleStruct(1, true, "abc".to_owned()) ); let url_params = create_url_params(vec![("a", "1"), ("b", "2"), ("c", "3")]); @@ -654,7 +654,7 @@ mod tests { assert_eq!( Struct::deserialize(PathDeserializer::new(&url_params)).unwrap(), Struct { - c: "abc".to_string(), + c: "abc".to_owned(), b: true, a: 1, } @@ -668,7 +668,7 @@ mod tests { >::deserialize(PathDeserializer::new(&url_params)).unwrap(), [("a", "1"), ("b", "true"), ("c", "abc")] .iter() - .map(|(key, value)| ((*key).to_string(), (*value).to_string())) + .map(|(key, value)| ((*key).to_owned(), (*value).to_owned())) .collect() ); } diff --git a/axum/src/extract/rejection.rs b/axum/src/extract/rejection.rs index 40c0d467..ed73d14b 100644 --- a/axum/src/extract/rejection.rs +++ b/axum/src/extract/rejection.rs @@ -7,20 +7,7 @@ use crate::{ }; use http_body::Full; -define_rejection! { - #[status = INTERNAL_SERVER_ERROR] - #[body = "Extensions taken by other extractor"] - /// Rejection used if the request extension has been taken by another - /// extractor. - pub struct ExtensionsAlreadyExtracted; -} - -define_rejection! { - #[status = INTERNAL_SERVER_ERROR] - #[body = "Headers taken by other extractor"] - /// Rejection used if the headers has been taken by another extractor. - pub struct HeadersAlreadyExtracted; -} +pub use axum_core::extract::rejection::*; #[cfg(feature = "json")] define_rejection! { @@ -47,22 +34,6 @@ define_rejection! { pub struct MissingExtension(Error); } -define_rejection! { - #[status = BAD_REQUEST] - #[body = "Failed to buffer the request body"] - /// Rejection type for extractors that buffer the request body. Used if the - /// request body cannot be buffered due to an error. - pub struct FailedToBufferBody(Error); -} - -define_rejection! { - #[status = BAD_REQUEST] - #[body = "Request body didn't contain valid UTF-8"] - /// Rejection type used when buffering the request into a [`String`] if the - /// body doesn't contain valid UTF-8. - pub struct InvalidUtf8(Error); -} - define_rejection! { #[status = PAYLOAD_TOO_LARGE] #[body = "Request payload is too large"] @@ -86,14 +57,6 @@ define_rejection! { pub struct MissingRouteParams; } -define_rejection! { - #[status = INTERNAL_SERVER_ERROR] - #[body = "Cannot have two request body extractors for a single handler"] - /// Rejection type used if you try and extract the request body more than - /// once. - pub struct BodyAlreadyExtracted; -} - define_rejection! { #[status = BAD_REQUEST] #[body = "Form requests must have `Content-Type: x-www-form-urlencoded`"] @@ -186,8 +149,7 @@ composite_rejection! { pub enum FormRejection { InvalidFormContentType, FailedToDeserializeQueryString, - FailedToBufferBody, - BodyAlreadyExtracted, + BytesRejection, HeadersAlreadyExtracted, } } @@ -202,7 +164,7 @@ composite_rejection! { pub enum JsonRejection { InvalidJsonBody, MissingJsonContentType, - BodyAlreadyExtracted, + BytesRejection, HeadersAlreadyExtracted, } } @@ -229,51 +191,6 @@ composite_rejection! { } } -composite_rejection! { - /// Rejection used for [`Bytes`](bytes::Bytes). - /// - /// Contains one variant for each way the [`Bytes`](bytes::Bytes) extractor - /// can fail. - pub enum BytesRejection { - BodyAlreadyExtracted, - FailedToBufferBody, - } -} - -composite_rejection! { - /// Rejection used for [`String`]. - /// - /// Contains one variant for each way the [`String`] extractor can fail. - pub enum StringRejection { - BodyAlreadyExtracted, - FailedToBufferBody, - InvalidUtf8, - } -} - -composite_rejection! { - /// Rejection used for [`Request<_>`]. - /// - /// Contains one variant for each way the [`Request<_>`] extractor can fail. - /// - /// [`Request<_>`]: http::Request - pub enum RequestAlreadyExtracted { - BodyAlreadyExtracted, - HeadersAlreadyExtracted, - ExtensionsAlreadyExtracted, - } -} - -composite_rejection! { - /// Rejection used for [`http::request::Parts`]. - /// - /// Contains one variant for each way the [`http::request::Parts`] extractor can fail. - pub enum RequestPartsAlreadyExtracted { - HeadersAlreadyExtracted, - ExtensionsAlreadyExtracted, - } -} - define_rejection! { #[status = INTERNAL_SERVER_ERROR] #[body = "No matched path found"] diff --git a/axum/src/extract/request_parts.rs b/axum/src/extract/request_parts.rs index 560dc0b7..ce24f800 100644 --- a/axum/src/extract/request_parts.rs +++ b/axum/src/extract/request_parts.rs @@ -3,7 +3,7 @@ use crate::{body::Body, BoxError, Error}; use async_trait::async_trait; use bytes::Bytes; use futures_util::stream::Stream; -use http::{Extensions, HeaderMap, Method, Request, Uri, Version}; +use http::Uri; use http_body::Body as HttpBody; use std::{ convert::Infallible, @@ -13,78 +13,6 @@ use std::{ }; use sync_wrapper::SyncWrapper; -#[async_trait] -impl FromRequest for Request -where - B: Send, -{ - type Rejection = RequestAlreadyExtracted; - - async fn from_request(req: &mut RequestParts) -> Result { - let req = std::mem::replace( - req, - RequestParts { - method: req.method.clone(), - version: req.version, - uri: req.uri.clone(), - headers: None, - extensions: None, - body: None, - }, - ); - - let err = match req.try_into_request() { - Ok(req) => return Ok(req), - Err(err) => err, - }; - - match err.downcast::() { - Ok(err) => return Err(err), - Err(err) => unreachable!( - "Unexpected error type from `try_into_request`: `{:?}`. This is a bug in axum, please file an issue", - err, - ), - } - } -} - -#[async_trait] -impl FromRequest for RawBody -where - B: Send, -{ - type Rejection = BodyAlreadyExtracted; - - async fn from_request(req: &mut RequestParts) -> Result { - let body = take_body(req)?; - Ok(Self(body)) - } -} - -#[async_trait] -impl FromRequest for Method -where - B: Send, -{ - type Rejection = Infallible; - - async fn from_request(req: &mut RequestParts) -> Result { - Ok(req.method().clone()) - } -} - -#[async_trait] -impl FromRequest for Uri -where - B: Send, -{ - type Rejection = Infallible; - - async fn from_request(req: &mut RequestParts) -> Result { - Ok(req.uri().clone()) - } -} - /// Extractor that gets the original request URI regardless of nesting. /// /// This is necessary since [`Uri`](http::Uri), when used as an extractor, will @@ -133,42 +61,6 @@ where } } -#[async_trait] -impl FromRequest for Version -where - B: Send, -{ - type Rejection = Infallible; - - async fn from_request(req: &mut RequestParts) -> Result { - Ok(req.version()) - } -} - -#[async_trait] -impl FromRequest for HeaderMap -where - B: Send, -{ - type Rejection = HeadersAlreadyExtracted; - - async fn from_request(req: &mut RequestParts) -> Result { - req.take_headers().ok_or(HeadersAlreadyExtracted) - } -} - -#[async_trait] -impl FromRequest for Extensions -where - B: Send, -{ - type Rejection = ExtensionsAlreadyExtracted; - - async fn from_request(req: &mut RequestParts) -> Result { - req.take_extensions().ok_or(ExtensionsAlreadyExtracted) - } -} - /// Extractor that extracts the request body as a [`Stream`]. /// /// Note if your request body is [`body::Body`] you can extract that directly @@ -272,101 +164,27 @@ fn body_stream_traits() { pub struct RawBody(pub B); #[async_trait] -impl FromRequest for Bytes -where - B: http_body::Body + Send, - B::Data: Send, - B::Error: Into, -{ - type Rejection = BytesRejection; - - async fn from_request(req: &mut RequestParts) -> Result { - let body = take_body(req)?; - - let bytes = hyper::body::to_bytes(body) - .await - .map_err(FailedToBufferBody::from_err)?; - - Ok(bytes) - } -} - -#[async_trait] -impl FromRequest for Body { - type Rejection = BodyAlreadyExtracted; - - async fn from_request(req: &mut RequestParts) -> Result { - req.take_body().ok_or(BodyAlreadyExtracted) - } -} - -#[async_trait] -impl FromRequest for String -where - B: http_body::Body + Send, - B::Data: Send, - B::Error: Into, -{ - type Rejection = StringRejection; - - async fn from_request(req: &mut RequestParts) -> Result { - let body = take_body(req)?; - - let bytes = hyper::body::to_bytes(body) - .await - .map_err(FailedToBufferBody::from_err)? - .to_vec(); - - let string = String::from_utf8(bytes).map_err(InvalidUtf8::from_err)?; - - Ok(string) - } -} - -#[async_trait] -impl FromRequest for http::request::Parts +impl FromRequest for RawBody where B: Send, { - type Rejection = RequestPartsAlreadyExtracted; + type Rejection = BodyAlreadyExtracted; async fn from_request(req: &mut RequestParts) -> Result { - let method = unwrap_infallible(Method::from_request(req).await); - let uri = unwrap_infallible(Uri::from_request(req).await); - let version = unwrap_infallible(Version::from_request(req).await); - let headers = HeaderMap::from_request(req).await?; - let extensions = Extensions::from_request(req).await?; - - let mut temp_request = Request::new(()); - *temp_request.method_mut() = method; - *temp_request.uri_mut() = uri; - *temp_request.version_mut() = version; - *temp_request.headers_mut() = headers; - *temp_request.extensions_mut() = extensions; - - let (parts, _) = temp_request.into_parts(); - - Ok(parts) - } -} - -fn unwrap_infallible(result: Result) -> T { - match result { - Ok(value) => value, - Err(err) => match err {}, + let body = take_body(req)?; + Ok(Self(body)) } } #[cfg(test)] mod tests { - use super::*; use crate::{ body::Body, routing::{get, post}, test_helpers::*, AddExtensionLayer, Router, }; - use http::StatusCode; + use http::{Method, Request, StatusCode}; #[tokio::test] async fn multiple_request_extractors() { diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index b98f8c4d..0b5e1a69 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -218,7 +218,7 @@ where let sec_websocket_key = if let Some(key) = req .headers_mut() - .ok_or(HeadersAlreadyExtracted)? + .ok_or_else(HeadersAlreadyExtracted::default)? .remove(header::SEC_WEBSOCKET_KEY) { key @@ -228,13 +228,13 @@ where let on_upgrade = req .extensions_mut() - .ok_or(ExtensionsAlreadyExtracted)? + .ok_or_else(ExtensionsAlreadyExtracted::default)? .remove::() .unwrap(); let sec_websocket_protocol = req .headers() - .ok_or(HeadersAlreadyExtracted)? + .ok_or_else(HeadersAlreadyExtracted::default)? .get(header::SEC_WEBSOCKET_PROTOCOL) .cloned(); @@ -253,7 +253,11 @@ fn header_eq( key: HeaderName, value: &'static str, ) -> Result { - if let Some(header) = req.headers().ok_or(HeadersAlreadyExtracted)?.get(&key) { + if let Some(header) = req + .headers() + .ok_or_else(HeadersAlreadyExtracted::default)? + .get(&key) + { Ok(header.as_bytes().eq_ignore_ascii_case(value.as_bytes())) } else { Ok(false) @@ -265,7 +269,11 @@ fn header_contains( key: HeaderName, value: &'static str, ) -> Result { - let header = if let Some(header) = req.headers().ok_or(HeadersAlreadyExtracted)?.get(&key) { + let header = if let Some(header) = req + .headers() + .ok_or_else(HeadersAlreadyExtracted::default)? + .get(&key) + { header } else { return Ok(false); diff --git a/axum/src/json.rs b/axum/src/json.rs index c161c5a3..5c5c88ec 100644 --- a/axum/src/json.rs +++ b/axum/src/json.rs @@ -1,6 +1,6 @@ use crate::{ body::{self, BoxBody}, - extract::{rejection::*, take_body, FromRequest, RequestParts}, + extract::{rejection::*, FromRequest, RequestParts}, response::IntoResponse, BoxError, }; @@ -98,16 +98,10 @@ where type Rejection = JsonRejection; async fn from_request(req: &mut RequestParts) -> Result { - use bytes::Buf; - if json_content_type(req)? { - let body = take_body(req)?; + let bytes = bytes::Bytes::from_request(req).await?; - let buf = hyper::body::aggregate(body) - .await - .map_err(InvalidJsonBody::from_err)?; - - let value = serde_json::from_reader(buf.reader()).map_err(InvalidJsonBody::from_err)?; + let value = serde_json::from_slice(&bytes).map_err(InvalidJsonBody::from_err)?; Ok(Json(value)) } else { @@ -119,7 +113,7 @@ where fn json_content_type(req: &RequestParts) -> Result { let content_type = if let Some(content_type) = req .headers() - .ok_or(HeadersAlreadyExtracted)? + .ok_or_else(HeadersAlreadyExtracted::default)? .get(header::CONTENT_TYPE) { content_type diff --git a/axum/src/lib.rs b/axum/src/lib.rs index 953310e4..775c833c 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -13,6 +13,7 @@ //! - [Middleware](#middleware) //! - [Routing to services and backpressure](#routing-to-services-and-backpressure) //! - [Sharing state with handlers](#sharing-state-with-handlers) +//! - [Building integrations for axum](#building-integrations-for-axum) //! - [Required dependencies](#required-dependencies) //! - [Examples](#examples) //! - [Feature flags](#feature-flags) @@ -259,6 +260,12 @@ //! # }; //! ``` //! +//! # Building integrations for axum +//! +//! Libraries authors that want to provide [`FromRequest`] or [`IntoResponse`] implementations +//! should depend on the [`axum-core`] crate, instead of `axum` if possible. [`axum-core`] contains +//! core types and traits and is less likely to receive breaking changes. +//! //! # Required dependencies //! //! To use axum there are a few dependencies you have pull in as well: @@ -331,6 +338,7 @@ //! [`Handler`]: crate::handler::Handler //! [`Infallible`]: std::convert::Infallible //! [load shed]: tower::load_shed +//! [`axum-core`]: http://crates.io/crates/axum-core #![warn( clippy::all, @@ -377,7 +385,6 @@ pub(crate) mod macros; mod add_extension; -mod error; #[cfg(feature = "json")] mod json; mod util; @@ -404,7 +411,7 @@ pub use hyper::Server; #[cfg(feature = "json")] pub use self::json::Json; #[doc(inline)] -pub use self::{error::Error, routing::Router}; +pub use self::routing::Router; -/// Alias for a type-erased error type. -pub type BoxError = Box; +#[doc(inline)] +pub use axum_core::{BoxError, Error}; diff --git a/axum/src/response/mod.rs b/axum/src/response/mod.rs index f551a826..1ac275f4 100644 --- a/axum/src/response/mod.rs +++ b/axum/src/response/mod.rs @@ -1,18 +1,10 @@ #![doc = include_str!("../docs/response.md")] -use crate::{ - body::{boxed, BoxBody}, - BoxError, -}; +use axum_core::body::{boxed, BoxBody}; use bytes::Bytes; -use http::{header, HeaderMap, HeaderValue, Response, StatusCode}; -use http_body::{ - combinators::{MapData, MapErr}, - Empty, Full, -}; -use std::{borrow::Cow, convert::Infallible}; +use http::{header, HeaderValue, Response}; +use http_body::Full; -mod headers; mod redirect; pub mod sse; @@ -22,342 +14,10 @@ pub mod sse; pub use crate::Json; #[doc(inline)] -pub use self::{headers::Headers, redirect::Redirect, sse::Sse}; +pub use axum_core::response::{Headers, IntoResponse}; -/// Trait for generating responses. -/// -/// Types that implement `IntoResponse` can be returned from handlers. -/// -/// # Implementing `IntoResponse` -/// -/// You generally shouldn't have to implement `IntoResponse` manually, as axum -/// provides implementations for many common types. -/// -/// However it might be necessary if you have a custom error type that you want -/// to return from handlers: -/// -/// ```rust -/// use axum::{ -/// Router, -/// body::{self, BoxBody, Bytes}, -/// routing::get, -/// http::{Response, StatusCode}, -/// response::IntoResponse, -/// }; -/// -/// enum MyError { -/// SomethingWentWrong, -/// SomethingElseWentWrong, -/// } -/// -/// impl IntoResponse for MyError { -/// fn into_response(self) -> Response { -/// let body = match self { -/// MyError::SomethingWentWrong => { -/// body::boxed(body::Full::from("something went wrong")) -/// }, -/// MyError::SomethingElseWentWrong => { -/// body::boxed(body::Full::from("something else went wrong")) -/// }, -/// }; -/// -/// Response::builder() -/// .status(StatusCode::INTERNAL_SERVER_ERROR) -/// .body(body) -/// .unwrap() -/// } -/// } -/// -/// // `Result` can now be returned from handlers -/// let app = Router::new().route("/", get(handler)); -/// -/// async fn handler() -> Result<(), MyError> { -/// Err(MyError::SomethingWentWrong) -/// } -/// # async { -/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -/// # }; -/// ``` -/// -/// Or if you have a custom body type you'll also need to implement -/// `IntoResponse` for it: -/// -/// ```rust -/// use axum::{ -/// body::{self, BoxBody}, -/// routing::get, -/// response::IntoResponse, -/// Router, -/// }; -/// use http_body::Body; -/// use http::{Response, HeaderMap}; -/// use bytes::Bytes; -/// use std::{ -/// convert::Infallible, -/// task::{Poll, Context}, -/// pin::Pin, -/// }; -/// -/// struct MyBody; -/// -/// // First implement `Body` for `MyBody`. This could for example use -/// // some custom streaming protocol. -/// impl Body for MyBody { -/// type Data = Bytes; -/// type Error = Infallible; -/// -/// fn poll_data( -/// self: Pin<&mut Self>, -/// cx: &mut Context<'_> -/// ) -> Poll>> { -/// # unimplemented!() -/// // ... -/// } -/// -/// fn poll_trailers( -/// self: Pin<&mut Self>, -/// cx: &mut Context<'_> -/// ) -> Poll, Self::Error>> { -/// # unimplemented!() -/// // ... -/// } -/// } -/// -/// // Now we can implement `IntoResponse` directly for `MyBody` -/// impl IntoResponse for MyBody { -/// fn into_response(self) -> Response { -/// Response::new(body::boxed(self)) -/// } -/// } -/// -/// // We don't need to implement `IntoResponse for Response` as that is -/// // covered by a blanket implementation in axum. -/// -/// // `MyBody` can now be returned from handlers. -/// let app = Router::new().route("/", get(|| async { MyBody })); -/// # async { -/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -/// # }; -/// ``` -pub trait IntoResponse { - /// Create a response. - fn into_response(self) -> Response; -} - -impl IntoResponse for () { - fn into_response(self) -> Response { - Response::new(boxed(Empty::new())) - } -} - -impl IntoResponse for Infallible { - fn into_response(self) -> Response { - match self {} - } -} - -impl IntoResponse for Result -where - T: IntoResponse, - E: IntoResponse, -{ - fn into_response(self) -> Response { - match self { - Ok(value) => value.into_response(), - Err(err) => err.into_response(), - } - } -} - -impl IntoResponse for Response -where - B: http_body::Body + Send + 'static, - B::Error: Into, -{ - fn into_response(self) -> Response { - self.map(boxed) - } -} - -macro_rules! impl_into_response_for_body { - ($body:ty) => { - impl IntoResponse for $body { - fn into_response(self) -> Response { - Response::new(boxed(self)) - } - } - }; -} - -impl_into_response_for_body!(hyper::Body); -impl_into_response_for_body!(Full); -impl_into_response_for_body!(Empty); - -impl IntoResponse for http::response::Parts { - fn into_response(self) -> Response { - Response::from_parts(self, boxed(Empty::new())) - } -} - -impl IntoResponse for http_body::combinators::BoxBody -where - E: Into + 'static, -{ - fn into_response(self) -> Response { - Response::new(boxed(self)) - } -} - -impl IntoResponse for http_body::combinators::UnsyncBoxBody -where - E: Into + 'static, -{ - fn into_response(self) -> Response { - Response::new(boxed(self)) - } -} - -impl IntoResponse for MapData -where - B: http_body::Body + Send + 'static, - F: FnMut(B::Data) -> Bytes + Send + 'static, - B::Error: Into, -{ - fn into_response(self) -> Response { - Response::new(boxed(self)) - } -} - -impl IntoResponse for MapErr -where - B: http_body::Body + Send + 'static, - F: FnMut(B::Error) -> E + Send + 'static, - E: Into, -{ - fn into_response(self) -> Response { - Response::new(boxed(self)) - } -} - -impl IntoResponse for &'static str { - #[inline] - fn into_response(self) -> Response { - Cow::Borrowed(self).into_response() - } -} - -impl IntoResponse for String { - #[inline] - fn into_response(self) -> Response { - Cow::<'static, str>::Owned(self).into_response() - } -} - -impl IntoResponse for Cow<'static, str> { - fn into_response(self) -> Response { - let mut res = Response::new(boxed(Full::from(self))); - res.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()), - ); - res - } -} - -impl IntoResponse for Bytes { - fn into_response(self) -> Response { - let mut res = Response::new(boxed(Full::from(self))); - res.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM.as_ref()), - ); - res - } -} - -impl IntoResponse for &'static [u8] { - fn into_response(self) -> Response { - let mut res = Response::new(boxed(Full::from(self))); - res.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM.as_ref()), - ); - res - } -} - -impl IntoResponse for Vec { - fn into_response(self) -> Response { - let mut res = Response::new(boxed(Full::from(self))); - res.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM.as_ref()), - ); - res - } -} - -impl IntoResponse for Cow<'static, [u8]> { - fn into_response(self) -> Response { - let mut res = Response::new(boxed(Full::from(self))); - res.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static(mime::APPLICATION_OCTET_STREAM.as_ref()), - ); - res - } -} - -impl IntoResponse for StatusCode { - fn into_response(self) -> Response { - Response::builder() - .status(self) - .body(boxed(Empty::new())) - .unwrap() - } -} - -impl IntoResponse for (StatusCode, T) -where - T: IntoResponse, -{ - fn into_response(self) -> Response { - let mut res = self.1.into_response(); - *res.status_mut() = self.0; - res - } -} - -impl IntoResponse for (HeaderMap, T) -where - T: IntoResponse, -{ - fn into_response(self) -> Response { - let mut res = self.1.into_response(); - res.headers_mut().extend(self.0); - res - } -} - -impl IntoResponse for (StatusCode, HeaderMap, T) -where - T: IntoResponse, -{ - fn into_response(self) -> Response { - let mut res = self.2.into_response(); - *res.status_mut() = self.0; - res.headers_mut().extend(self.1); - res - } -} - -impl IntoResponse for HeaderMap { - fn into_response(self) -> Response { - let mut res = Response::new(boxed(Empty::new())); - *res.headers_mut() = self; - res - } -} +#[doc(inline)] +pub use self::{redirect::Redirect, sse::Sse}; /// An HTML response. /// @@ -388,7 +48,11 @@ impl From for Html { #[cfg(test)] mod tests { use super::*; - use http::header::{HeaderMap, HeaderName}; + use http::{ + header::{HeaderMap, HeaderName}, + StatusCode, + }; + use http_body::Empty; #[test] fn test_merge_headers() { diff --git a/examples/customize-extractor-error/src/main.rs b/examples/customize-extractor-error/src/main.rs index e0dc411a..4e951001 100644 --- a/examples/customize-extractor-error/src/main.rs +++ b/examples/customize-extractor-error/src/main.rs @@ -73,9 +73,6 @@ where JsonRejection::MissingJsonContentType(err) => { (StatusCode::BAD_REQUEST, err.to_string().into()) } - JsonRejection::BodyAlreadyExtracted(err) => { - (StatusCode::INTERNAL_SERVER_ERROR, err.to_string().into()) - } JsonRejection::HeadersAlreadyExtracted(err) => { (StatusCode::INTERNAL_SERVER_ERROR, err.to_string().into()) }