From bdefe596484f735338697b4b9d7176690e0d2680 Mon Sep 17 00:00:00 2001 From: zleyyij <75810274+zleyyij@users.noreply.github.com> Date: Sat, 28 Sep 2024 13:40:52 -0600 Subject: [PATCH 01/22] Add multipart/form-data response builders to axum-extra (#2654) --- axum-extra/Cargo.toml | 5 +- axum-extra/src/response/mod.rs | 3 + axum-extra/src/response/multiple.rs | 296 ++++++++++++++++++++++++++++ 3 files changed, 302 insertions(+), 2 deletions(-) create mode 100644 axum-extra/src/response/multiple.rs diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index 7c967749..b297df68 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -12,7 +12,7 @@ repository = "https://github.com/tokio-rs/axum" version = "0.9.4" [features] -default = ["tracing"] +default = ["tracing", "multipart"] async-read-body = ["dep:tokio-util", "tokio-util?/io", "dep:tokio"] attachment = ["dep:tracing"] @@ -31,7 +31,7 @@ json-lines = [ "tokio-stream?/io-util", "dep:tokio", ] -multipart = ["dep:multer"] +multipart = ["dep:multer", "dep:fastrand"] protobuf = ["dep:prost"] query = ["dep:serde_html_form"] tracing = ["axum-core/tracing", "axum/tracing"] @@ -56,6 +56,7 @@ tower-service = "0.3" # optional dependencies axum-macros = { path = "../axum-macros", version = "0.4.2", optional = true } cookie = { package = "cookie", version = "0.18.0", features = ["percent-encode"], optional = true } +fastrand = { version = "2.1.0", optional = true } form_urlencoded = { version = "1.1.0", optional = true } headers = { version = "0.4.0", optional = true } multer = { version = "3.0.0", optional = true } diff --git a/axum-extra/src/response/mod.rs b/axum-extra/src/response/mod.rs index d17f7be6..3b4b14c0 100644 --- a/axum-extra/src/response/mod.rs +++ b/axum-extra/src/response/mod.rs @@ -6,6 +6,9 @@ mod erased_json; #[cfg(feature = "attachment")] mod attachment; +#[cfg(feature = "multipart")] +pub mod multiple; + #[cfg(feature = "erased-json")] pub use erased_json::ErasedJson; diff --git a/axum-extra/src/response/multiple.rs b/axum-extra/src/response/multiple.rs new file mode 100644 index 00000000..1fdbd8e7 --- /dev/null +++ b/axum-extra/src/response/multiple.rs @@ -0,0 +1,296 @@ +//! Generate forms to use in responses. + +use axum::response::{IntoResponse, Response}; +use fastrand; +use http::{header, HeaderMap, StatusCode}; +use mime::Mime; + +/// Create multipart forms to be used in API responses. +/// +/// This struct implements [`IntoResponse`], and so it can be returned from a handler. +#[derive(Debug)] +pub struct MultipartForm { + parts: Vec, +} + +impl MultipartForm { + /// Initialize a new multipart form with the provided vector of parts. + /// + /// # Examples + /// + /// ```rust + /// use axum_extra::response::multiple::{MultipartForm, Part}; + /// + /// let parts: Vec = vec![Part::text("foo".to_string(), "abc"), Part::text("bar".to_string(), "def")]; + /// let form = MultipartForm::with_parts(parts); + /// ``` + #[deprecated] + pub fn with_parts(parts: Vec) -> Self { + MultipartForm { parts } + } +} + +impl IntoResponse for MultipartForm { + fn into_response(self) -> Response { + // see RFC5758 for details + let boundary = generate_boundary(); + let mut headers = HeaderMap::new(); + let mime_type: Mime = match format!("multipart/form-data; boundary={}", boundary).parse() { + Ok(m) => m, + // Realistically this should never happen unless the boundary generation code + // is modified, and that will be caught by unit tests + Err(_) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + "Invalid multipart boundary generated", + ) + .into_response() + } + }; + // The use of unwrap is safe here because mime types are inherently string representable + headers.insert(header::CONTENT_TYPE, mime_type.to_string().parse().unwrap()); + let mut serialized_form: Vec = Vec::new(); + for part in self.parts { + // for each part, the boundary is preceded by two dashes + serialized_form.extend_from_slice(format!("--{}\r\n", boundary).as_bytes()); + serialized_form.extend_from_slice(&part.serialize()); + } + serialized_form.extend_from_slice(format!("--{}--", boundary).as_bytes()); + (headers, serialized_form).into_response() + } +} + +// Valid settings for that header are: "base64", "quoted-printable", "8bit", "7bit", and "binary". +/// A single part of a multipart form as defined by +/// +/// and RFC5758. +#[derive(Debug)] +pub struct Part { + // Every part is expected to contain: + // - a [Content-Disposition](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition + // header, where `Content-Disposition` is set to `form-data`, with a parameter of `name` that is set to + // the name of the field in the form. In the below example, the name of the field is `user`: + // ``` + // Content-Disposition: form-data; name="user" + // ``` + // If the field contains a file, then the `filename` parameter may be set to the name of the file. + // Handling for non-ascii field names is not done here, support for non-ascii characters may be encoded using + // methodology described in RFC 2047. + // - (optionally) a `Content-Type` header, which if not set, defaults to `text/plain`. + // If the field contains a file, then the file should be identified with that file's MIME type (eg: `image/gif`). + // If the `MIME` type is not known or specified, then the MIME type should be set to `application/octet-stream`. + /// The name of the part in question + name: String, + /// If the part should be treated as a file, the filename that should be attached that part + filename: Option, + /// The `Content-Type` header. While not strictly required, it is always set here + mime_type: Mime, + /// The content/body of the part + contents: Vec, +} + +impl Part { + /// Create a new part with `Content-Type` of `text/plain` with the supplied name and contents. + /// + /// This form will not have a defined file name. + /// + /// # Examples + /// + /// ```rust + /// use axum_extra::response::multiple::{MultipartForm, Part}; + /// + /// // create a form with a single part that has a field with a name of "foo", + /// // and a value of "abc" + /// let parts: Vec = vec![Part::text("foo".to_string(), "abc")]; + /// let form = MultipartForm::from_iter(parts); + /// ``` + pub fn text(name: String, contents: &str) -> Self { + Self { + name, + filename: None, + mime_type: mime::TEXT_PLAIN_UTF_8, + contents: contents.as_bytes().to_vec(), + } + } + + /// Create a new part containing a generic file, with a `Content-Type` of `application/octet-stream` + /// using the provided file name, field name, and contents. + /// + /// If the MIME type of the file is known, consider using `Part::raw_part`. + /// + /// # Examples + /// + /// ```rust + /// use axum_extra::response::multiple::{MultipartForm, Part}; + /// + /// // create a form with a single part that has a field with a name of "foo", + /// // with a file name of "foo.txt", and with the specified contents + /// let parts: Vec = vec![Part::file("foo", "foo.txt", vec![0x68, 0x68, 0x20, 0x6d, 0x6f, 0x6d])]; + /// let form = MultipartForm::from_iter(parts); + /// ``` + pub fn file(field_name: &str, file_name: &str, contents: Vec) -> Self { + Self { + name: field_name.to_owned(), + filename: Some(file_name.to_owned()), + // If the `MIME` type is not known or specified, then the MIME type should be set to `application/octet-stream`. + // See RFC2388 section 3 for specifics. + mime_type: mime::APPLICATION_OCTET_STREAM, + contents, + } + } + + /// Create a new part with more fine-grained control over the semantics of that part. + /// + /// The caller is assumed to have set a valid MIME type. + /// + /// This function will return an error if the provided MIME type is not valid. + /// + /// # Examples + /// + /// ```rust + /// use axum_extra::response::multiple::{MultipartForm, Part}; + /// + /// // create a form with a single part that has a field with a name of "part_name", + /// // with a MIME type of "application/json", and the supplied contents. + /// let parts: Vec = vec![Part::raw_part("part_name", "application/json", vec![0x68, 0x68, 0x20, 0x6d, 0x6f, 0x6d], None).expect("MIME type must be valid")]; + /// let form = MultipartForm::from_iter(parts); + /// ``` + pub fn raw_part( + name: &str, + mime_type: &str, + contents: Vec, + filename: Option<&str>, + ) -> Result { + let mime_type = mime_type.parse().map_err(|_| "Invalid MIME type")?; + Ok(Self { + name: name.to_owned(), + filename: filename.map(|f| f.to_owned()), + mime_type, + contents, + }) + } + + /// Serialize this part into a chunk that can be easily inserted into a larger form + pub(super) fn serialize(&self) -> Vec { + // A part is serialized in this general format: + // // the filename is optional + // Content-Disposition: form-data; name="FIELD_NAME"; filename="FILENAME"\r\n + // // the mime type (not strictly required by the spec, but always sent here) + // Content-Type: mime/type\r\n + // // a blank line, then the contents of the file start + // \r\n + // CONTENTS\r\n + + // Format what we can as a string, then handle the rest at a byte level + let mut serialized_part = format!("Content-Disposition: form-data; name=\"{}\"", self.name); + // specify a filename if one was set + if let Some(filename) = &self.filename { + serialized_part += &format!("; filename=\"{}\"", filename); + } + serialized_part += "\r\n"; + // specify the MIME type + serialized_part += &format!("Content-Type: {}\r\n", self.mime_type); + serialized_part += "\r\n"; + let mut part_bytes = serialized_part.as_bytes().to_vec(); + part_bytes.extend_from_slice(&self.contents); + part_bytes.extend_from_slice(b"\r\n"); + + part_bytes + } +} + +impl FromIterator for MultipartForm { + fn from_iter>(iter: T) -> Self { + Self { + parts: iter.into_iter().collect(), + } + } +} + +/// A boundary is defined as a user defined (arbitrary) value that does not occur in any of the data. +/// +/// Because the specification does not clearly define a methodology for generating boundaries, this implementation +/// follow's Reqwest's, and generates a boundary in the format of `XXXXXXXX-XXXXXXXX-XXXXXXXX-XXXXXXXX` where `XXXXXXXX` +/// is a hexadecimal representation of a pseudo randomly generated u64. +fn generate_boundary() -> String { + let a = fastrand::u64(0..u64::MAX); + let b = fastrand::u64(0..u64::MAX); + let c = fastrand::u64(0..u64::MAX); + let d = fastrand::u64(0..u64::MAX); + format!("{a:016x}-{b:016x}-{c:016x}-{d:016x}") +} + +#[cfg(test)] +mod tests { + use super::{generate_boundary, MultipartForm, Part}; + use axum::{body::Body, http}; + use axum::{routing::get, Router}; + use http::{Request, Response}; + use http_body_util::BodyExt; + use mime::Mime; + use tower::ServiceExt; + + #[tokio::test] + async fn process_form() -> Result<(), Box> { + // create a boilerplate handle that returns a form + async fn handle() -> MultipartForm { + let parts: Vec = vec![ + Part::text("part1".to_owned(), "basictext"), + Part::file( + "part2", + "file.txt", + vec![0x68, 0x69, 0x20, 0x6d, 0x6f, 0x6d], + ), + Part::raw_part("part3", "text/plain", b"rawpart".to_vec(), None).unwrap(), + ]; + MultipartForm::from_iter(parts) + } + + // make a request to that handle + let app = Router::new().route("/", get(handle)); + let response: Response<_> = app + .oneshot(Request::builder().uri("/").body(Body::empty())?) + .await?; + // content_type header + let ct_header = response.headers().get("content-type").unwrap().to_str()?; + let boundary = ct_header.split("boundary=").nth(1).unwrap().to_owned(); + let body: &[u8] = &response.into_body().collect().await?.to_bytes(); + assert_eq!( + std::str::from_utf8(body)?, + &format!( + "--{boundary}\r\n\ + Content-Disposition: form-data; name=\"part1\"\r\n\ + Content-Type: text/plain; charset=utf-8\r\n\ + \r\n\ + basictext\r\n\ + --{boundary}\r\n\ + Content-Disposition: form-data; name=\"part2\"; filename=\"file.txt\"\r\n\ + Content-Type: application/octet-stream\r\n\ + \r\n\ + hi mom\r\n\ + --{boundary}\r\n\ + Content-Disposition: form-data; name=\"part3\"\r\n\ + Content-Type: text/plain\r\n\ + \r\n\ + rawpart\r\n\ + --{boundary}--", + boundary = boundary + ) + ); + + Ok(()) + } + + #[test] + fn valid_boundary_generation() { + for _ in 0..256 { + let boundary = generate_boundary(); + let mime_type: Result = + format!("multipart/form-data; boundary={}", boundary).parse(); + assert!( + mime_type.is_ok(), + "The generated boundary was unable to be parsed into a valid mime type." + ); + } + } +} From 5b6d1caaa71c91a50c824dc60cde2553ae1e0682 Mon Sep 17 00:00:00 2001 From: zleyyij <75810274+zleyyij@users.noreply.github.com> Date: Fri, 4 Oct 2024 10:59:26 -0600 Subject: [PATCH 02/22] axum-extra: Remove stray deprecation in multipart builder (#2957) --- axum-extra/src/response/multiple.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/axum-extra/src/response/multiple.rs b/axum-extra/src/response/multiple.rs index 1fdbd8e7..250dc024 100644 --- a/axum-extra/src/response/multiple.rs +++ b/axum-extra/src/response/multiple.rs @@ -24,7 +24,6 @@ impl MultipartForm { /// let parts: Vec = vec![Part::text("foo".to_string(), "abc"), Part::text("bar".to_string(), "def")]; /// let form = MultipartForm::with_parts(parts); /// ``` - #[deprecated] pub fn with_parts(parts: Vec) -> Self { MultipartForm { parts } } From b71d4fa557fe6dfd00ab0b2f6010de4c40fc99a0 Mon Sep 17 00:00:00 2001 From: Sabrina Jewson Date: Sun, 6 Oct 2024 20:01:10 +0100 Subject: [PATCH 03/22] Add `axum_extra::json!` (#2962) --- axum-extra/CHANGELOG.md | 6 +++ axum-extra/Cargo.toml | 6 ++- axum-extra/src/response/erased_json.rs | 71 ++++++++++++++++++++++++++ axum-extra/src/response/mod.rs | 5 ++ axum/src/json.rs | 5 ++ 5 files changed, 91 insertions(+), 2 deletions(-) diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index d26c506d..066029ae 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog], and this project adheres to [Semantic Versioning]. +# Unreleased + +- **added:** Add `json!` for easy construction of JSON responses ([#2962]) + +[#2962]: https://github.com/tokio-rs/axum/pull/2962 + # 0.9.4 - **added:** The `response::Attachment` type ([#2789]) diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index b297df68..aed676d1 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -20,7 +20,7 @@ cookie = ["dep:cookie"] cookie-private = ["cookie", "cookie?/private"] cookie-signed = ["cookie", "cookie?/signed"] cookie-key-expansion = ["cookie", "cookie?/key-expansion"] -erased-json = ["dep:serde_json"] +erased-json = ["dep:serde_json", "dep:typed-json"] form = ["dep:serde_html_form"] json-deserializer = ["dep:serde_json", "dep:serde_path_to_error"] json-lines = [ @@ -69,9 +69,11 @@ tokio = { version = "1.19", optional = true } tokio-stream = { version = "0.1.9", optional = true } tokio-util = { version = "0.7", optional = true } tracing = { version = "0.1.37", default-features = false, optional = true } +typed-json = { version = "0.1.1", optional = true } [dev-dependencies] -axum = { path = "../axum", version = "0.7.2" } +axum = { path = "../axum", features = ["macros"] } +axum-macros = { path = "../axum-macros", features = ["__private"] } hyper = "1.0.0" reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] } serde = { version = "1.0", features = ["derive"] } diff --git a/axum-extra/src/response/erased_json.rs b/axum-extra/src/response/erased_json.rs index 6e94a267..5088ff35 100644 --- a/axum-extra/src/response/erased_json.rs +++ b/axum-extra/src/response/erased_json.rs @@ -12,6 +12,15 @@ use serde::Serialize; /// This allows returning a borrowing type from a handler, or returning different response /// types as JSON from different branches inside a handler. /// +/// Like [`axum::Json`], +/// if the [`Serialize`] implementation fails +/// or if a map with non-string keys is used, +/// a 500 response will be issued +/// whose body is the error message in UTF-8. +/// +/// This can be constructed using [`new`](ErasedJson::new) +/// or the [`json!`](crate::json) macro. +/// /// # Example /// /// ```rust @@ -72,3 +81,65 @@ impl IntoResponse for ErasedJson { } } } + +/// Construct an [`ErasedJson`] response from a JSON literal. +/// +/// A `Content-Type: application/json` header is automatically added. +/// Any variable or expression implementing [`Serialize`] +/// can be interpolated as a value in the literal. +/// If the [`Serialize`] implementation fails, +/// or if a map with non-string keys is used, +/// a 500 response will be issued +/// whose body is the error message in UTF-8. +/// +/// Internally, +/// this function uses the [`typed_json::json!`] macro, +/// allowing it to perform far fewer allocations +/// than a dynamic macro like [`serde_json::json!`] would – +/// it's equivalent to if you had just written +/// `derive(Serialize)` on a struct. +/// +/// # Examples +/// +/// ``` +/// use axum::{ +/// Router, +/// extract::Path, +/// response::Response, +/// routing::get, +/// }; +/// use axum_extra::response::ErasedJson; +/// +/// async fn get_user(Path(user_id) : Path) -> ErasedJson { +/// let user_name = find_user_name(user_id).await; +/// axum_extra::json!({ "name": user_name }) +/// } +/// +/// async fn find_user_name(user_id: u64) -> String { +/// // ... +/// # unimplemented!() +/// } +/// +/// let app = Router::new().route("/users/{id}", get(get_user)); +/// # let _: Router = app; +/// ``` +/// +/// Trailing commas are allowed in both arrays and objects. +/// +/// ``` +/// let response = axum_extra::json!(["trailing",]); +/// ``` +#[macro_export] +macro_rules! json { + ($($t:tt)*) => { + $crate::response::ErasedJson::new( + $crate::response::__private_erased_json::typed_json::json!($($t)*) + ) + } +} + +/// Not public API. Re-exported as `crate::response::__private_erased_json`. +#[doc(hidden)] +pub mod private { + pub use typed_json; +} diff --git a/axum-extra/src/response/mod.rs b/axum-extra/src/response/mod.rs index 3b4b14c0..04a69ec5 100644 --- a/axum-extra/src/response/mod.rs +++ b/axum-extra/src/response/mod.rs @@ -12,6 +12,11 @@ pub mod multiple; #[cfg(feature = "erased-json")] pub use erased_json::ErasedJson; +/// _not_ public API +#[cfg(feature = "erased-json")] +#[doc(hidden)] +pub use erased_json::private as __private_erased_json; + #[cfg(feature = "json-lines")] #[doc(no_inline)] pub use crate::json_lines::JsonLines; diff --git a/axum/src/json.rs b/axum/src/json.rs index 854ead4e..b1135742 100644 --- a/axum/src/json.rs +++ b/axum/src/json.rs @@ -55,6 +55,11 @@ use serde::{de::DeserializeOwned, Serialize}; /// When used as a response, it can serialize any type that implements [`serde::Serialize`] to /// `JSON`, and will automatically set `Content-Type: application/json` header. /// +/// If the [`Serialize`] implementation decides to fail +/// or if a map with non-string keys is used, +/// a 500 response will be issued +/// whose body is the error message in UTF-8. +/// /// # Response example /// /// ``` From c2e7fc0b3d337cba663bebaab900ddee14c695e9 Mon Sep 17 00:00:00 2001 From: Sabrina Jewson Date: Sun, 6 Oct 2024 20:09:06 +0100 Subject: [PATCH 04/22] Add `MethodFilter::CONNECT` (#2961) --- axum-extra/CHANGELOG.md | 2 + axum-extra/src/routing/mod.rs | 23 +++++++++ axum/CHANGELOG.md | 7 +++ axum/src/routing/method_filter.rs | 34 +++++++++++-- axum/src/routing/method_routing.rs | 79 +++++++++++++++++++++++++++++- axum/src/routing/mod.rs | 6 +-- 6 files changed, 143 insertions(+), 8 deletions(-) diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index 066029ae..25b84c0d 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -7,8 +7,10 @@ and this project adheres to [Semantic Versioning]. # Unreleased +- **added:** Add `RouterExt::typed_connect` ([#2961]) - **added:** Add `json!` for easy construction of JSON responses ([#2962]) +[#2961]: https://github.com/tokio-rs/axum/pull/2961 [#2962]: https://github.com/tokio-rs/axum/pull/2962 # 0.9.4 diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index a294c547..5fce9591 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -131,6 +131,19 @@ pub trait RouterExt: sealed::Sealed { T: SecondElementIs

+ 'static, P: TypedPath; + /// Add a typed `CONNECT` route to the router. + /// + /// The path will be inferred from the first argument to the handler function which must + /// implement [`TypedPath`]. + /// + /// See [`TypedPath`] for more details and examples. + #[cfg(feature = "typed-routing")] + fn typed_connect(self, handler: H) -> Self + where + H: axum::handler::Handler, + T: SecondElementIs

+ 'static, + P: TypedPath; + /// Add another route to the router with an additional "trailing slash redirect" route. /// /// If you add a route _without_ a trailing slash, such as `/foo`, this method will also add a @@ -255,6 +268,16 @@ where self.route(P::PATH, axum::routing::trace(handler)) } + #[cfg(feature = "typed-routing")] + fn typed_connect(self, handler: H) -> Self + where + H: axum::handler::Handler, + T: SecondElementIs

+ 'static, + P: TypedPath, + { + self.route(P::PATH, axum::routing::connect(handler)) + } + #[track_caller] fn route_with_tsr(mut self, path: &str, method_router: MethodRouter) -> Self where diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index aa9d8067..ed355e82 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +# Unreleased + +- **added:** Add `MethodFilter::CONNECT`, `routing::connect[_service]` + and `MethodRouter::connect[_service]` ([#2961]) + +[#2961]: https://github.com/tokio-rs/axum/pull/2961 + # 0.7.7 - **change**: Remove manual tables of content from the documentation, since diff --git a/axum/src/routing/method_filter.rs b/axum/src/routing/method_filter.rs index 1cea4235..040783ec 100644 --- a/axum/src/routing/method_filter.rs +++ b/axum/src/routing/method_filter.rs @@ -9,6 +9,24 @@ use std::{ pub struct MethodFilter(u16); impl MethodFilter { + /// Match `CONNECT` requests. + /// + /// This is useful for implementing HTTP/2's [extended CONNECT method], + /// in which the `:protocol` pseudoheader is read + /// (using [`hyper::ext::Protocol`]) + /// and the connection upgraded to a bidirectional byte stream + /// (using [`hyper::upgrade::on`]). + /// + /// As seen in the [HTTP Upgrade Token Registry], + /// common uses include WebSockets and proxying UDP or IP – + /// though note that when using [`WebSocketUpgrade`] + /// it's more useful to use [`any`](crate::routing::any) + /// as HTTP/1.1 WebSockets need to support `GET`. + /// + /// [extended CONNECT]: https://www.rfc-editor.org/rfc/rfc8441.html#section-4 + /// [HTTP Upgrade Token Registry]: https://www.iana.org/assignments/http-upgrade-tokens/http-upgrade-tokens.xhtml + /// [`WebSocketUpgrade`]: crate::extract::WebSocketUpgrade + pub const CONNECT: Self = Self::from_bits(0b0_0000_0001); /// Match `DELETE` requests. pub const DELETE: Self = Self::from_bits(0b0_0000_0010); /// Match `GET` requests. @@ -71,6 +89,7 @@ impl TryFrom for MethodFilter { fn try_from(m: Method) -> Result { match m { + Method::CONNECT => Ok(MethodFilter::CONNECT), Method::DELETE => Ok(MethodFilter::DELETE), Method::GET => Ok(MethodFilter::GET), Method::HEAD => Ok(MethodFilter::HEAD), @@ -90,6 +109,11 @@ mod tests { #[test] fn from_http_method() { + assert_eq!( + MethodFilter::try_from(Method::CONNECT).unwrap(), + MethodFilter::CONNECT + ); + assert_eq!( MethodFilter::try_from(Method::DELETE).unwrap(), MethodFilter::DELETE @@ -130,9 +154,11 @@ mod tests { MethodFilter::TRACE ); - assert!(MethodFilter::try_from(http::Method::CONNECT) - .unwrap_err() - .to_string() - .contains("CONNECT")); + assert!( + MethodFilter::try_from(http::Method::from_bytes(b"CUSTOM").unwrap()) + .unwrap_err() + .to_string() + .contains("CUSTOM") + ); } } diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 1eb6075b..5ed4f6a9 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -59,6 +59,19 @@ macro_rules! top_level_service_fn { ); }; + ( + $name:ident, CONNECT + ) => { + top_level_service_fn!( + /// Route `CONNECT` requests to the given service. + /// + /// See [`MethodFilter::CONNECT`] for when you'd want to use this, + /// and [`get_service`] for an example. + $name, + CONNECT + ); + }; + ( $name:ident, $method:ident ) => { @@ -118,6 +131,19 @@ macro_rules! top_level_handler_fn { ); }; + ( + $name:ident, CONNECT + ) => { + top_level_handler_fn!( + /// Route `CONNECT` requests to the given handler. + /// + /// See [`MethodFilter::CONNECT`] for when you'd want to use this, + /// and [`get`] for an example. + $name, + CONNECT + ); + }; + ( $name:ident, $method:ident ) => { @@ -187,6 +213,19 @@ macro_rules! chained_service_fn { ); }; + ( + $name:ident, CONNECT + ) => { + chained_service_fn!( + /// Chain an additional service that will only accept `CONNECT` requests. + /// + /// See [`MethodFilter::CONNECT`] for when you'd want to use this, + /// and [`MethodRouter::get_service`] for an example. + $name, + CONNECT + ); + }; + ( $name:ident, $method:ident ) => { @@ -249,6 +288,19 @@ macro_rules! chained_handler_fn { ); }; + ( + $name:ident, CONNECT + ) => { + chained_handler_fn!( + /// Chain an additional handler that will only accept `CONNECT` requests. + /// + /// See [`MethodFilter::CONNECT`] for when you'd want to use this, + /// and [`MethodRouter::get`] for an example. + $name, + CONNECT + ); + }; + ( $name:ident, $method:ident ) => { @@ -278,6 +330,7 @@ macro_rules! chained_handler_fn { }; } +top_level_service_fn!(connect_service, CONNECT); top_level_service_fn!(delete_service, DELETE); top_level_service_fn!(get_service, GET); top_level_service_fn!(head_service, HEAD); @@ -381,6 +434,7 @@ where .skip_allow_header() } +top_level_handler_fn!(connect, CONNECT); top_level_handler_fn!(delete, DELETE); top_level_handler_fn!(get, GET); top_level_handler_fn!(head, HEAD); @@ -497,6 +551,7 @@ pub struct MethodRouter { post: MethodEndpoint, put: MethodEndpoint, trace: MethodEndpoint, + connect: MethodEndpoint, fallback: Fallback, allow_header: AllowHeader, } @@ -538,6 +593,7 @@ impl fmt::Debug for MethodRouter { .field("post", &self.post) .field("put", &self.put) .field("trace", &self.trace) + .field("connect", &self.connect) .field("fallback", &self.fallback) .field("allow_header", &self.allow_header) .finish() @@ -582,6 +638,7 @@ where ) } + chained_handler_fn!(connect, CONNECT); chained_handler_fn!(delete, DELETE); chained_handler_fn!(get, GET); chained_handler_fn!(head, HEAD); @@ -689,6 +746,7 @@ where post: MethodEndpoint::None, put: MethodEndpoint::None, trace: MethodEndpoint::None, + connect: MethodEndpoint::None, allow_header: AllowHeader::None, fallback: Fallback::Default(fallback), } @@ -705,6 +763,7 @@ where post: self.post.with_state(&state), put: self.put.with_state(&state), trace: self.trace.with_state(&state), + connect: self.connect.with_state(&state), allow_header: self.allow_header, fallback: self.fallback.with_state(state), } @@ -853,9 +912,20 @@ where &["DELETE"], ); + set_endpoint( + "CONNECT", + &mut self.options, + &endpoint, + filter, + MethodFilter::CONNECT, + &mut self.allow_header, + &["CONNECT"], + ); + self } + chained_service_fn!(connect_service, CONNECT); chained_service_fn!(delete_service, DELETE); chained_service_fn!(get_service, GET); chained_service_fn!(head_service, HEAD); @@ -899,6 +969,7 @@ where post: self.post.map(layer_fn.clone()), put: self.put.map(layer_fn.clone()), trace: self.trace.map(layer_fn.clone()), + connect: self.connect.map(layer_fn.clone()), fallback: self.fallback.map(layer_fn), allow_header: self.allow_header, } @@ -923,6 +994,7 @@ where && self.post.is_none() && self.put.is_none() && self.trace.is_none() + && self.connect.is_none() { panic!( "Adding a route_layer before any routes is a no-op. \ @@ -943,7 +1015,8 @@ where self.patch = self.patch.map(layer_fn.clone()); self.post = self.post.map(layer_fn.clone()); self.put = self.put.map(layer_fn.clone()); - self.trace = self.trace.map(layer_fn); + self.trace = self.trace.map(layer_fn.clone()); + self.connect = self.connect.map(layer_fn); self } @@ -984,6 +1057,7 @@ where self.post = merge_inner(path, "POST", self.post, other.post); self.put = merge_inner(path, "PUT", self.put, other.put); self.trace = merge_inner(path, "TRACE", self.trace, other.trace); + self.connect = merge_inner(path, "CONNECT", self.connect, other.connect); self.fallback = self .fallback @@ -1059,6 +1133,7 @@ where post, put, trace, + connect, fallback, allow_header, } = self; @@ -1072,6 +1147,7 @@ where call!(req, method, PUT, put); call!(req, method, DELETE, delete); call!(req, method, TRACE, trace); + call!(req, method, CONNECT, connect); let future = fallback.clone().call_with_state(req, state); @@ -1114,6 +1190,7 @@ impl Clone for MethodRouter { post: self.post.clone(), put: self.put.clone(), trace: self.trace.clone(), + connect: self.connect.clone(), fallback: self.fallback.clone(), allow_header: self.allow_header.clone(), } diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index dc6ca815..822be773 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -40,9 +40,9 @@ mod tests; pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route}; pub use self::method_routing::{ - any, any_service, delete, delete_service, get, get_service, head, head_service, on, on_service, - options, options_service, patch, patch_service, post, post_service, put, put_service, trace, - trace_service, MethodRouter, + any, any_service, connect, connect_service, delete, delete_service, get, get_service, head, + head_service, on, on_service, options, options_service, patch, patch_service, post, + post_service, put, put_service, trace, trace_service, MethodRouter, }; macro_rules! panic_on_err { From 0d4e224caeaeeb3b40127c00ffeabd01563d5674 Mon Sep 17 00:00:00 2001 From: novacrazy Date: Wed, 9 Oct 2024 22:16:18 +0200 Subject: [PATCH 05/22] Avoid one state clone This is an extraction of a part of https://github.com/tokio-rs/axum/pull/2865 --- axum/src/routing/method_routing.rs | 41 ++++++++++++++++-------------- axum/src/routing/route.rs | 8 ++++++ axum/src/routing/tests/mod.rs | 2 +- 3 files changed, 31 insertions(+), 20 deletions(-) diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 5ed4f6a9..28fa314f 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1097,32 +1097,35 @@ where } pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture { + let method = req.method(); + let is_head = *method == Method::HEAD; + macro_rules! call { ( - $req:expr, - $method:expr, $method_variant:ident, $svc:expr ) => { - if $method == Method::$method_variant { + if *method == Method::$method_variant { match $svc { MethodEndpoint::None => {} MethodEndpoint::Route(route) => { - return RouteFuture::from_future(route.clone().oneshot_inner($req)) - .strip_body($method == Method::HEAD); + return RouteFuture::from_future( + route.clone().oneshot_inner_owned(req), + ) + .strip_body(is_head); } MethodEndpoint::BoxedHandler(handler) => { let route = handler.clone().into_route(state); - return RouteFuture::from_future(route.clone().oneshot_inner($req)) - .strip_body($method == Method::HEAD); + return RouteFuture::from_future( + route.clone().oneshot_inner_owned(req), + ) + .strip_body(is_head); } } } }; } - let method = req.method().clone(); - // written with a pattern match like this to ensure we call all routes let Self { get, @@ -1138,16 +1141,16 @@ where allow_header, } = self; - call!(req, method, HEAD, head); - call!(req, method, HEAD, get); - call!(req, method, GET, get); - call!(req, method, POST, post); - call!(req, method, OPTIONS, options); - call!(req, method, PATCH, patch); - call!(req, method, PUT, put); - call!(req, method, DELETE, delete); - call!(req, method, TRACE, trace); - call!(req, method, CONNECT, connect); + call!(HEAD, head); + call!(HEAD, get); + call!(GET, get); + call!(POST, post); + call!(OPTIONS, options); + call!(PATCH, patch); + call!(PUT, put); + call!(DELETE, delete); + call!(TRACE, trace); + call!(CONNECT, connect); let future = fallback.clone().call_with_state(req, state); diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index 242e25f3..25d41daf 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -49,6 +49,14 @@ impl Route { self.0.get_mut().unwrap().clone().oneshot(req) } + /// Variant of [`Route::oneshot_inner`] that takes ownership of the route to avoid cloning. + pub(crate) fn oneshot_inner_owned( + self, + req: Request, + ) -> Oneshot, Request> { + self.0.into_inner().unwrap().oneshot(req) + } + pub(crate) fn layer(self, layer: L) -> Route where L: Layer> + Clone + Send + 'static, diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 144c870d..3dad0728 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -952,7 +952,7 @@ async fn state_isnt_cloned_too_much() { client.get("/").await; - assert_eq!(COUNT.load(Ordering::SeqCst), 4); + assert_eq!(COUNT.load(Ordering::SeqCst), 3); } #[crate::test] From e66be09d5934428a0f53ea5fa88cbd1a7c30e401 Mon Sep 17 00:00:00 2001 From: Yann Simon Date: Wed, 9 Oct 2024 22:57:35 +0200 Subject: [PATCH 06/22] Cover clone in fallback with tests --- axum/src/routing/tests/fallback.rs | 18 +++++++ axum/src/routing/tests/mod.rs | 47 +++-------------- .../test_helpers/counting_cloneable_state.rs | 52 +++++++++++++++++++ axum/src/test_helpers/mod.rs | 2 + 4 files changed, 79 insertions(+), 40 deletions(-) create mode 100644 axum/src/test_helpers/counting_cloneable_state.rs diff --git a/axum/src/routing/tests/fallback.rs b/axum/src/routing/tests/fallback.rs index ee116a41..4ff55ae4 100644 --- a/axum/src/routing/tests/fallback.rs +++ b/axum/src/routing/tests/fallback.rs @@ -328,3 +328,21 @@ async fn merge_router_with_fallback_into_empty() { assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } + +#[crate::test] +async fn state_isnt_cloned_too_much_with_fallback() { + let state = CountingCloneableState::new(); + + let app = Router::new() + .fallback(|_: State| async {}) + .with_state(state.clone()); + + let client = TestClient::new(app); + + // ignore clones made during setup + state.setup_done(); + + client.get("/does-not-exist").await; + + assert_eq!(state.count(), 4); +} diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 3dad0728..7e91a977 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -16,6 +16,7 @@ use crate::{ BoxError, Extension, Json, Router, ServiceExt, }; use axum_core::extract::Request; +use counting_cloneable_state::CountingCloneableState; use futures_util::stream::StreamExt; use http::{ header::{ALLOW, CONTENT_LENGTH, HOST}, @@ -27,7 +28,7 @@ use serde_json::json; use std::{ convert::Infallible, future::{ready, IntoFuture, Ready}, - sync::atomic::{AtomicBool, AtomicUsize, Ordering}, + sync::atomic::{AtomicUsize, Ordering}, task::{Context, Poll}, time::Duration, }; @@ -905,54 +906,20 @@ fn test_path_for_nested_route() { #[crate::test] async fn state_isnt_cloned_too_much() { - static SETUP_DONE: AtomicBool = AtomicBool::new(false); - static COUNT: AtomicUsize = AtomicUsize::new(0); - - struct AppState; - - impl Clone for AppState { - fn clone(&self) -> Self { - #[rustversion::since(1.66)] - #[track_caller] - fn count() { - if SETUP_DONE.load(Ordering::SeqCst) { - let bt = std::backtrace::Backtrace::force_capture(); - let bt = bt - .to_string() - .lines() - .filter(|line| line.contains("axum") || line.contains("./src")) - .collect::>() - .join("\n"); - println!("AppState::Clone:\n===============\n{bt}\n"); - COUNT.fetch_add(1, Ordering::SeqCst); - } - } - - #[rustversion::not(since(1.66))] - fn count() { - if SETUP_DONE.load(Ordering::SeqCst) { - COUNT.fetch_add(1, Ordering::SeqCst); - } - } - - count(); - - Self - } - } + let state = CountingCloneableState::new(); let app = Router::new() - .route("/", get(|_: State| async {})) - .with_state(AppState); + .route("/", get(|_: State| async {})) + .with_state(state.clone()); let client = TestClient::new(app); // ignore clones made during setup - SETUP_DONE.store(true, Ordering::SeqCst); + state.setup_done(); client.get("/").await; - assert_eq!(COUNT.load(Ordering::SeqCst), 3); + assert_eq!(state.count(), 3); } #[crate::test] diff --git a/axum/src/test_helpers/counting_cloneable_state.rs b/axum/src/test_helpers/counting_cloneable_state.rs new file mode 100644 index 00000000..762d5ce9 --- /dev/null +++ b/axum/src/test_helpers/counting_cloneable_state.rs @@ -0,0 +1,52 @@ +use std::sync::{ + atomic::{AtomicBool, AtomicUsize, Ordering}, + Arc, +}; + +pub(crate) struct CountingCloneableState { + state: Arc, +} + +struct InnerState { + setup_done: AtomicBool, + count: AtomicUsize, +} + +impl CountingCloneableState { + pub(crate) fn new() -> Self { + let inner_state = InnerState { + setup_done: AtomicBool::new(false), + count: AtomicUsize::new(0), + }; + CountingCloneableState { + state: Arc::new(inner_state), + } + } + + pub(crate) fn setup_done(&self) { + self.state.setup_done.store(true, Ordering::SeqCst); + } + + pub(crate) fn count(&self) -> usize { + self.state.count.load(Ordering::SeqCst) + } +} + +impl Clone for CountingCloneableState { + fn clone(&self) -> Self { + let state = self.state.clone(); + if state.setup_done.load(Ordering::SeqCst) { + let bt = std::backtrace::Backtrace::force_capture(); + let bt = bt + .to_string() + .lines() + .filter(|line| line.contains("axum") || line.contains("./src")) + .collect::>() + .join("\n"); + println!("AppState::Clone:\n===============\n{bt}\n"); + state.count.fetch_add(1, Ordering::SeqCst); + } + + CountingCloneableState { state } + } +} diff --git a/axum/src/test_helpers/mod.rs b/axum/src/test_helpers/mod.rs index c6ae1bff..5c29f78d 100644 --- a/axum/src/test_helpers/mod.rs +++ b/axum/src/test_helpers/mod.rs @@ -7,6 +7,8 @@ pub(crate) use self::test_client::*; pub(crate) mod tracing_helpers; +pub(crate) mod counting_cloneable_state; + pub(crate) fn assert_send() {} pub(crate) fn assert_sync() {} From 092719c2175dd1185e534b53a237b32b371d0065 Mon Sep 17 00:00:00 2001 From: novacrazy Date: Thu, 14 Nov 2024 22:34:06 +0100 Subject: [PATCH 07/22] Remove one clone --- axum/src/routing/mod.rs | 8 ++++---- axum/src/routing/tests/fallback.rs | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 822be773..61e7e0f6 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -656,14 +656,14 @@ where } } - fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture { + fn call_with_state(self, req: Request, state: S) -> RouteFuture { match self { Fallback::Default(route) | Fallback::Service(route) => { - RouteFuture::from_future(route.oneshot_inner(req)) + RouteFuture::from_future(route.oneshot_inner_owned(req)) } Fallback::BoxedHandler(handler) => { - let mut route = handler.clone().into_route(state); - RouteFuture::from_future(route.oneshot_inner(req)) + let route = handler.clone().into_route(state); + RouteFuture::from_future(route.oneshot_inner_owned(req)) } } } diff --git a/axum/src/routing/tests/fallback.rs b/axum/src/routing/tests/fallback.rs index 4ff55ae4..02850b19 100644 --- a/axum/src/routing/tests/fallback.rs +++ b/axum/src/routing/tests/fallback.rs @@ -344,5 +344,5 @@ async fn state_isnt_cloned_too_much_with_fallback() { client.get("/does-not-exist").await; - assert_eq!(state.count(), 4); + assert_eq!(state.count(), 3); } From 296dfe1a40b0cebba7a3e88a99cd0829629c3889 Mon Sep 17 00:00:00 2001 From: Yann Simon Date: Thu, 10 Oct 2024 11:34:01 +0200 Subject: [PATCH 08/22] Add test to cover state cloning in layer --- axum/src/routing/tests/mod.rs | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 7e91a977..83b2ed81 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -3,6 +3,7 @@ use crate::{ error_handling::HandleErrorLayer, extract::{self, DefaultBodyLimit, FromRef, Path, State}, handler::{Handler, HandlerWithoutStateExt}, + middleware::{self, Next}, response::{IntoResponse, Response}, routing::{ delete, get, get_service, on, on_service, patch, patch_service, @@ -922,6 +923,26 @@ async fn state_isnt_cloned_too_much() { assert_eq!(state.count(), 3); } +#[crate::test] +async fn state_isnt_cloned_too_much_in_layer() { + async fn layer(State(_): State, req: Request, next: Next) -> Response { + next.run(req).await + } + + let state = CountingCloneableState::new(); + + let app = Router::new().layer(middleware::from_fn_with_state(state.clone(), layer)); + + let client = TestClient::new(app); + + // ignore clones made during setup + state.setup_done(); + + client.get("/").await; + + assert_eq!(state.count(), 4); +} + #[crate::test] async fn logging_rejections() { #[derive(Deserialize, Eq, PartialEq, Debug)] From c417a2814246845d768a51f146ca0d8fe3432a38 Mon Sep 17 00:00:00 2001 From: novacrazy Date: Thu, 10 Oct 2024 11:35:04 +0200 Subject: [PATCH 09/22] Avoid cloning the state in layer --- axum/src/routing/path_router.rs | 2 +- axum/src/routing/route.rs | 6 ++++++ axum/src/routing/tests/mod.rs | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs index e132f1dc..fd7d9ff1 100644 --- a/axum/src/routing/path_router.rs +++ b/axum/src/routing/path_router.rs @@ -357,7 +357,7 @@ where Endpoint::MethodRouter(method_router) => { Ok(method_router.call_with_state(req, state)) } - Endpoint::Route(route) => Ok(route.clone().call(req)), + Endpoint::Route(route) => Ok(route.clone().call_owned(req)), } } // explicitly handle all variants in case matchit adds diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index 25d41daf..3067e675 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -42,6 +42,12 @@ impl Route { ))) } + /// Variant of [`Route::call`] that takes ownership of the route to avoid cloning. + pub(crate) fn call_owned(self, req: Request) -> RouteFuture { + let req = req.map(Body::new); + RouteFuture::from_future(self.oneshot_inner_owned(req)) + } + pub(crate) fn oneshot_inner( &mut self, req: Request, diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 83b2ed81..c7ae1f70 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -940,7 +940,7 @@ async fn state_isnt_cloned_too_much_in_layer() { client.get("/").await; - assert_eq!(state.count(), 4); + assert_eq!(state.count(), 3); } #[crate::test] From 236781cfdc76f0db20fd4decccde38e894b499d6 Mon Sep 17 00:00:00 2001 From: novacrazy Date: Thu, 10 Oct 2024 10:36:22 +0200 Subject: [PATCH 10/22] Avoid cloning the uri/path --- axum/src/routing/path_router.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs index fd7d9ff1..5e317b90 100644 --- a/axum/src/routing/path_router.rs +++ b/axum/src/routing/path_router.rs @@ -331,9 +331,9 @@ where } } - let path = req.uri().path().to_owned(); + let (mut parts, body) = req.into_parts(); - match self.node.at(&path) { + match self.node.at(parts.uri.path()) { Ok(match_) => { let id = *match_.value; @@ -342,17 +342,18 @@ where crate::extract::matched_path::set_matched_path_for_request( id, &self.node.route_id_to_path, - req.extensions_mut(), + &mut parts.extensions, ); } - url_params::insert_url_params(req.extensions_mut(), match_.params); + url_params::insert_url_params(&mut parts.extensions, match_.params); let endpoint = self .routes .get(&id) .expect("no route for id. This is a bug in axum. Please file an issue"); + let req = Request::from_parts(parts, body); match endpoint { Endpoint::MethodRouter(method_router) => { Ok(method_router.call_with_state(req, state)) @@ -366,7 +367,7 @@ where MatchError::NotFound | MatchError::ExtraTrailingSlash | MatchError::MissingTrailingSlash, - ) => Err((req, state)), + ) => Err((Request::from_parts(parts, body), state)), } } From dcb4af68dee5fcd529e569a8da524dbbe35b1825 Mon Sep 17 00:00:00 2001 From: oxalica Date: Fri, 11 Oct 2024 11:54:11 -0400 Subject: [PATCH 11/22] Add `struct NoContent` as a self-described shortcut (#2978) --- axum/CHANGELOG.md | 2 ++ axum/src/response/mod.rs | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index ed355e82..5410ddca 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -9,8 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** Add `MethodFilter::CONNECT`, `routing::connect[_service]` and `MethodRouter::connect[_service]` ([#2961]) +- **added:** Add `NoContent` as a self-described shortcut for `StatusCode::NO_CONTENT` ([#2978]) [#2961]: https://github.com/tokio-rs/axum/pull/2961 +[#2978]: https://github.com/tokio-rs/axum/pull/2978 # 0.7.7 diff --git a/axum/src/response/mod.rs b/axum/src/response/mod.rs index 6cfd9b07..4d9664ca 100644 --- a/axum/src/response/mod.rs +++ b/axum/src/response/mod.rs @@ -1,7 +1,7 @@ #![doc = include_str!("../docs/response.md")] use axum_core::body::Body; -use http::{header, HeaderValue}; +use http::{header, HeaderValue, StatusCode}; mod redirect; @@ -60,6 +60,31 @@ impl From for Html { } } +/// An empty response with 204 No Content status. +/// +/// Due to historical and implementation reasons, the `IntoResponse` implementation of `()` +/// (unit type) returns an empty response with 200 [`StatusCode::OK`] status. +/// If you specifically want a 204 [`StatusCode::NO_CONTENT`] status, you can use either `StatusCode` type +/// directly, or this shortcut struct for self-documentation. +/// +/// ``` +/// use axum::{extract::Path, response::NoContent}; +/// +/// async fn delete_user(Path(user): Path) -> Result { +/// // ...access database... +/// # drop(user); +/// Ok(NoContent) +/// } +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct NoContent; + +impl IntoResponse for NoContent { + fn into_response(self) -> Response { + StatusCode::NO_CONTENT.into_response() + } +} + #[cfg(test)] mod tests { use crate::extract::Extension; @@ -224,4 +249,12 @@ mod tests { .route("/", get(header_array_extension_body)) .route("/", get(header_array_extension_mixed_body)); } + + #[test] + fn no_content() { + assert_eq!( + super::NoContent.into_response().status(), + StatusCode::NO_CONTENT, + ) + } } From a59a82c2af2b44bab01912b6953d8fefceb0daa5 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Fri, 4 Oct 2024 17:00:50 +0000 Subject: [PATCH 12/22] Replace Router::{map_inner, tap_inner_mut} by macros (#2954) --- axum/src/routing/mod.rs | 67 ++++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 31 deletions(-) diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 61e7e0f6..7499e1fc 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -103,6 +103,31 @@ pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_pa pub(crate) const FALLBACK_PARAM: &str = "__private__axum_fallback"; pub(crate) const FALLBACK_PARAM_PATH: &str = "/*__private__axum_fallback"; +macro_rules! map_inner { + ( $self_:ident, $inner:pat_param => $expr:expr) => { + #[allow(redundant_semicolons)] + { + let $inner = $self_.into_inner(); + Router { + inner: Arc::new($expr), + } + } + }; +} + +macro_rules! tap_inner { + ( $self_:ident, mut $inner:ident => { $($stmt:stmt)* } ) => { + #[allow(redundant_semicolons)] + { + let mut $inner = $self_.into_inner(); + $($stmt)* + Router { + inner: Arc::new($inner), + } + } + }; +} + impl Router where S: Clone + Send + Sync + 'static, @@ -122,26 +147,6 @@ where } } - fn map_inner(self, f: F) -> Router - where - F: FnOnce(RouterInner) -> RouterInner, - { - Router { - inner: Arc::new(f(self.into_inner())), - } - } - - fn tap_inner_mut(self, f: F) -> Self - where - F: FnOnce(&mut RouterInner), - { - let mut inner = self.into_inner(); - f(&mut inner); - Router { - inner: Arc::new(inner), - } - } - fn into_inner(self) -> RouterInner { match Arc::try_unwrap(self.inner) { Ok(inner) => inner, @@ -157,7 +162,7 @@ where #[doc = include_str!("../docs/routing/route.md")] #[track_caller] pub fn route(self, path: &str, method_router: MethodRouter) -> Self { - self.tap_inner_mut(|this| { + tap_inner!(self, mut this => { panic_on_err!(this.path_router.route(path, method_router)); }) } @@ -179,7 +184,7 @@ where Err(service) => service, }; - self.tap_inner_mut(|this| { + tap_inner!(self, mut this => { panic_on_err!(this.path_router.route_service(path, service)); }) } @@ -198,7 +203,7 @@ where catch_all_fallback: _, } = router.into_inner(); - self.tap_inner_mut(|this| { + tap_inner!(self, mut this => { panic_on_err!(this.path_router.nest(path, path_router)); if !default_fallback { @@ -215,7 +220,7 @@ where T::Response: IntoResponse, T::Future: Send + 'static, { - self.tap_inner_mut(|this| { + tap_inner!(self, mut this => { panic_on_err!(this.path_router.nest_service(path, service)); }) } @@ -237,7 +242,7 @@ where catch_all_fallback, } = other.into_inner(); - self.map_inner(|mut this| { + map_inner!(self, mut this => { panic_on_err!(this.path_router.merge(path_router)); match (this.default_fallback, default_fallback) { @@ -281,7 +286,7 @@ where >::Error: Into + 'static, >::Future: Send + 'static, { - self.map_inner(|this| RouterInner { + map_inner!(self, this => RouterInner { path_router: this.path_router.layer(layer.clone()), fallback_router: this.fallback_router.layer(layer.clone()), default_fallback: this.default_fallback, @@ -299,7 +304,7 @@ where >::Error: Into + 'static, >::Future: Send + 'static, { - self.map_inner(|this| RouterInner { + map_inner!(self, this => RouterInner { path_router: this.path_router.route_layer(layer), fallback_router: this.fallback_router, default_fallback: this.default_fallback, @@ -319,7 +324,7 @@ where H: Handler, T: 'static, { - self.tap_inner_mut(|this| { + tap_inner!(self, mut this => { this.catch_all_fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone())); }) @@ -336,14 +341,14 @@ where T::Future: Send + 'static, { let route = Route::new(service); - self.tap_inner_mut(|this| { + tap_inner!(self, mut this => { this.catch_all_fallback = Fallback::Service(route.clone()); }) .fallback_endpoint(Endpoint::Route(route)) } fn fallback_endpoint(self, endpoint: Endpoint) -> Self { - self.tap_inner_mut(|this| { + tap_inner!(self, mut this => { this.fallback_router.set_fallback(endpoint); this.default_fallback = false; }) @@ -351,7 +356,7 @@ where #[doc = include_str!("../docs/routing/with_state.md")] pub fn with_state(self, state: S) -> Router { - self.map_inner(|this| RouterInner { + map_inner!(self, this => RouterInner { path_router: this.path_router.with_state(state.clone()), fallback_router: this.fallback_router.with_state(state.clone()), default_fallback: this.default_fallback, From c5a3c66a2790419c21cf8c935dbca9767151b6b5 Mon Sep 17 00:00:00 2001 From: Leon Lux Date: Sat, 12 Oct 2024 11:36:56 +0200 Subject: [PATCH 13/22] Add method_not_allowed_fallback to router (#2903) Co-authored-by: Jonas Platte --- .../routing/method_not_allowed_fallback.md | 38 ++++++++++++++++ axum/src/routing/method_routing.rs | 13 ++++++ axum/src/routing/mod.rs | 12 +++++ axum/src/routing/path_router.rs | 17 ++++++- axum/src/routing/tests/fallback.rs | 45 +++++++++++++++++++ 5 files changed, 124 insertions(+), 1 deletion(-) create mode 100644 axum/src/docs/routing/method_not_allowed_fallback.md diff --git a/axum/src/docs/routing/method_not_allowed_fallback.md b/axum/src/docs/routing/method_not_allowed_fallback.md new file mode 100644 index 00000000..22905cd9 --- /dev/null +++ b/axum/src/docs/routing/method_not_allowed_fallback.md @@ -0,0 +1,38 @@ +Add a fallback [`Handler`] for the case where a route exists, but the method of the request is not supported. + +Sets a fallback on all previously registered [`MethodRouter`]s, +to be called when no matching method handler is set. + +```rust,no_run +use axum::{response::IntoResponse, routing::get, Router}; + +async fn hello_world() -> impl IntoResponse { + "Hello, world!\n" +} + +async fn default_fallback() -> impl IntoResponse { + "Default fallback\n" +} + +async fn handle_405() -> impl IntoResponse { + "Method not allowed fallback" +} + +#[tokio::main] +async fn main() { + let router = Router::new() + .route("/", get(hello_world)) + .fallback(default_fallback) + .method_not_allowed_fallback(handle_405); + + let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); + + axum::serve(listener, router).await.unwrap(); +} +``` + +The fallback only applies if there is a `MethodRouter` registered for a given path, +but the method used in the request is not specified. In the example, a `GET` on +`http://localhost:3000` causes the `hello_world` handler to react, while issuing a +`POST` triggers `handle_405`. Calling an entirely different route, like `http://localhost:3000/hello` +causes `default_fallback` to run. diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 28fa314f..ecc1dae0 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -658,6 +658,19 @@ where self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler)); self } + + /// Add a fallback [`Handler`] if no custom one has been provided. + pub(crate) fn default_fallback(self, handler: H) -> Self + where + H: Handler, + T: 'static, + S: Send + Sync + 'static, + { + match self.fallback { + Fallback::Default(_) => self.fallback(handler), + _ => self, + } + } } impl MethodRouter<(), Infallible> { diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 7499e1fc..57a3d430 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -347,6 +347,18 @@ where .fallback_endpoint(Endpoint::Route(route)) } + #[doc = include_str!("../docs/routing/method_not_allowed_fallback.md")] + pub fn method_not_allowed_fallback(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + tap_inner!(self, mut this => { + this.path_router + .method_not_allowed_fallback(handler.clone()) + }) + } + fn fallback_endpoint(self, endpoint: Endpoint) -> Self { tap_inner!(self, mut this => { this.fallback_router.set_fallback(endpoint); diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs index 5e317b90..400ce32d 100644 --- a/axum/src/routing/path_router.rs +++ b/axum/src/routing/path_router.rs @@ -1,4 +1,7 @@ -use crate::extract::{nested_path::SetNestedPath, Request}; +use crate::{ + extract::{nested_path::SetNestedPath, Request}, + handler::Handler, +}; use axum_core::response::IntoResponse; use matchit::MatchError; use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc}; @@ -79,6 +82,18 @@ where Ok(()) } + pub(super) fn method_not_allowed_fallback(&mut self, handler: H) + where + H: Handler, + T: 'static, + { + for (_, endpoint) in self.routes.iter_mut() { + if let Endpoint::MethodRouter(rt) = endpoint { + *rt = rt.clone().default_fallback(handler.clone()); + } + } + } + pub(super) fn route_service( &mut self, path: &str, diff --git a/axum/src/routing/tests/fallback.rs b/axum/src/routing/tests/fallback.rs index 02850b19..9dd1c6c2 100644 --- a/axum/src/routing/tests/fallback.rs +++ b/axum/src/routing/tests/fallback.rs @@ -329,6 +329,51 @@ async fn merge_router_with_fallback_into_empty() { assert_eq!(res.text().await, "outer"); } +#[crate::test] +async fn mna_fallback_with_existing_fallback() { + let app = Router::new() + .route( + "/", + get(|| async { "test" }).fallback(|| async { "index fallback" }), + ) + .route("/path", get(|| async { "path" })) + .method_not_allowed_fallback(|| async { "method not allowed fallback" }); + + let client = TestClient::new(app); + let index_fallback = client.post("/").await; + let method_not_allowed_fallback = client.post("/path").await; + + assert_eq!(index_fallback.text().await, "index fallback"); + assert_eq!( + method_not_allowed_fallback.text().await, + "method not allowed fallback" + ); +} + +#[crate::test] +async fn mna_fallback_with_state() { + let app = Router::new() + .route("/", get(|| async { "index" })) + .method_not_allowed_fallback(|State(state): State<&'static str>| async move { state }) + .with_state("state"); + + let client = TestClient::new(app); + let res = client.post("/").await; + assert_eq!(res.text().await, "state"); +} + +#[crate::test] +async fn mna_fallback_with_unused_state() { + let app = Router::new() + .route("/", get(|| async { "index" })) + .with_state(()) + .method_not_allowed_fallback(|| async move { "bla" }); + + let client = TestClient::new(app); + let res = client.post("/").await; + assert_eq!(res.text().await, "bla"); +} + #[crate::test] async fn state_isnt_cloned_too_much_with_fallback() { let state = CountingCloneableState::new(); From a7b7d56cfb02c8c16e0b60329ee302c9c3c3e2d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Ml=C3=A1dek?= Date: Mon, 14 Oct 2024 10:31:47 +0200 Subject: [PATCH 14/22] Update changelog (#2985) --- axum/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 5410ddca..a5d76c31 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,10 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased +- **added:** Add `method_not_allowed_fallback` to set a fallback when a path matches but there is no handler for the given HTTP method ([#2903]) - **added:** Add `MethodFilter::CONNECT`, `routing::connect[_service]` and `MethodRouter::connect[_service]` ([#2961]) - **added:** Add `NoContent` as a self-described shortcut for `StatusCode::NO_CONTENT` ([#2978]) +[#2903]: https://github.com/tokio-rs/axum/pull/2903 [#2961]: https://github.com/tokio-rs/axum/pull/2961 [#2978]: https://github.com/tokio-rs/axum/pull/2978 From 8d7eada03416a527b709651c66e3dcf7ed96ad49 Mon Sep 17 00:00:00 2001 From: Flamenco Date: Wed, 9 Oct 2024 06:22:06 -0400 Subject: [PATCH 15/22] Update middleware.md (#2967) --- axum/src/docs/middleware.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axum/src/docs/middleware.md b/axum/src/docs/middleware.md index 3a902372..bef5a8b5 100644 --- a/axum/src/docs/middleware.md +++ b/axum/src/docs/middleware.md @@ -190,7 +190,7 @@ You should use these when ## `tower::Service` and `Pin>` -For maximum control (and a more low level API) you can write you own middleware +For maximum control (and a more low level API) you can write your own middleware by implementing [`tower::Service`]: Use [`tower::Service`] with `Pin>` to write your middleware when: From eb6bea38d018aff4a18c40e0fc865cc02f212d99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Ml=C3=A1dek?= Date: Thu, 17 Oct 2024 12:27:47 +0200 Subject: [PATCH 16/22] chore: fix new clippy lint (#2994) --- axum-extra/src/extract/cookie/private.rs | 2 +- axum-extra/src/extract/cookie/signed.rs | 2 +- axum/src/extract/multipart.rs | 2 +- axum/src/routing/mod.rs | 4 ++-- axum/src/test_helpers/tracing_helpers.rs | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/axum-extra/src/extract/cookie/private.rs b/axum-extra/src/extract/cookie/private.rs index 911b0ef2..4b58dabe 100644 --- a/axum-extra/src/extract/cookie/private.rs +++ b/axum-extra/src/extract/cookie/private.rs @@ -291,7 +291,7 @@ struct PrivateCookieJarIter<'a, K> { iter: cookie::Iter<'a>, } -impl<'a, K> Iterator for PrivateCookieJarIter<'a, K> { +impl Iterator for PrivateCookieJarIter<'_, K> { type Item = Cookie<'static>; fn next(&mut self) -> Option { diff --git a/axum-extra/src/extract/cookie/signed.rs b/axum-extra/src/extract/cookie/signed.rs index b65df79f..9d0590a8 100644 --- a/axum-extra/src/extract/cookie/signed.rs +++ b/axum-extra/src/extract/cookie/signed.rs @@ -309,7 +309,7 @@ struct SignedCookieJarIter<'a, K> { iter: cookie::Iter<'a>, } -impl<'a, K> Iterator for SignedCookieJarIter<'a, K> { +impl Iterator for SignedCookieJarIter<'_, K> { type Item = Cookie<'static>; fn next(&mut self) -> Option { diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index 7a303a47..7441a798 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -109,7 +109,7 @@ pub struct Field<'a> { _multipart: &'a mut Multipart, } -impl<'a> Stream for Field<'a> { +impl Stream for Field<'_> { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 57a3d430..4ea49db0 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -551,7 +551,7 @@ pub struct RouterAsService<'a, B, S = ()> { _marker: PhantomData, } -impl<'a, B> Service> for RouterAsService<'a, B, ()> +impl Service> for RouterAsService<'_, B, ()> where B: HttpBody + Send + 'static, B::Error: Into, @@ -571,7 +571,7 @@ where } } -impl<'a, B, S> fmt::Debug for RouterAsService<'a, B, S> +impl fmt::Debug for RouterAsService<'_, B, S> where S: fmt::Debug, { diff --git a/axum/src/test_helpers/tracing_helpers.rs b/axum/src/test_helpers/tracing_helpers.rs index 2240717e..667c4994 100644 --- a/axum/src/test_helpers/tracing_helpers.rs +++ b/axum/src/test_helpers/tracing_helpers.rs @@ -73,7 +73,7 @@ impl<'a> MakeWriter<'a> for TestMakeWriter { struct Writer<'a>(&'a TestMakeWriter); -impl<'a> io::Write for Writer<'a> { +impl io::Write for Writer<'_> { fn write(&mut self, buf: &[u8]) -> io::Result { match &mut *self.0.write.lock().unwrap() { Some(vec) => { From 43814c174f5592a6dbad6d4cb8171af2ff073647 Mon Sep 17 00:00:00 2001 From: Erin <149203474+erin-desu@users.noreply.github.com> Date: Thu, 17 Oct 2024 16:43:14 +0200 Subject: [PATCH 17/22] Fix TSR redirecting to top-level inside nested Router (#2993) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: David Mládek --- axum-extra/Cargo.toml | 2 +- axum-extra/src/routing/mod.rs | 20 ++++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index aed676d1..4b02d023 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -39,7 +39,7 @@ typed-header = ["dep:headers"] typed-routing = ["dep:axum-macros", "dep:percent-encoding", "dep:serde_html_form", "dep:form_urlencoded"] [dependencies] -axum = { path = "../axum", version = "0.7.7", default-features = false } +axum = { path = "../axum", version = "0.7.7", default-features = false, features = ["original-uri"] } axum-core = { path = "../axum-core", version = "0.4.5" } bytes = "1.1.0" futures-util = { version = "0.3", default-features = false, features = ["alloc"] } diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index 5fce9591..445c3a2b 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -1,7 +1,7 @@ //! Additional types for defining routes. use axum::{ - extract::Request, + extract::{OriginalUri, Request}, response::{IntoResponse, Redirect, Response}, routing::{any, MethodRouter}, Router, @@ -313,7 +313,7 @@ fn add_tsr_redirect_route(router: Router, path: &str) -> Router where S: Clone + Send + Sync + 'static, { - async fn redirect_handler(uri: Uri) -> Response { + async fn redirect_handler(OriginalUri(uri): OriginalUri) -> Response { let new_uri = map_path(uri, |path| { path.strip_suffix('/') .map(Cow::Borrowed) @@ -432,6 +432,22 @@ mod tests { assert_eq!(res.headers()["location"], "/foo?a=a"); } + #[tokio::test] + async fn tsr_works_in_nested_router() { + let app = Router::new().nest( + "/neko", + Router::new().route_with_tsr("/nyan/", get(|| async {})), + ); + + let client = TestClient::new(app); + let res = client.get("/neko/nyan/").await; + assert_eq!(res.status(), StatusCode::OK); + + let res = client.get("/neko/nyan").await; + assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); + assert_eq!(res.headers()["location"], "/neko/nyan/"); + } + #[test] #[should_panic = "Cannot add a trailing slash redirect route for `/`"] fn tsr_at_root() { From 185804398f6d784eeb2eee5c9e522181b2298f84 Mon Sep 17 00:00:00 2001 From: Jan <59206115+Threated@users.noreply.github.com> Date: Thu, 17 Oct 2024 16:49:39 +0200 Subject: [PATCH 18/22] fix(sse): skip sse incompatible chars of `serde_json::RawValue` (#2992) --- axum/CHANGELOG.md | 2 ++ axum/Cargo.toml | 2 +- axum/src/response/sse.rs | 32 +++++++++++++++++++++++++++++++- 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index a5d76c31..69a42100 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased +- **fixed:** Skip SSE incompatible chars of `serde_json::RawValue` in `Event::json_data` ([#2992]) - **added:** Add `method_not_allowed_fallback` to set a fallback when a path matches but there is no handler for the given HTTP method ([#2903]) - **added:** Add `MethodFilter::CONNECT`, `routing::connect[_service]` and `MethodRouter::connect[_service]` ([#2961]) @@ -15,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#2903]: https://github.com/tokio-rs/axum/pull/2903 [#2961]: https://github.com/tokio-rs/axum/pull/2961 [#2978]: https://github.com/tokio-rs/axum/pull/2978 +[#2992]: https://github.com/tokio-rs/axum/pull/2992 # 0.7.7 diff --git a/axum/Cargo.toml b/axum/Cargo.toml index 263d3a96..6b15c455 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -117,7 +117,7 @@ quickcheck = "1.0" quickcheck_macros = "1.0" reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] } serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +serde_json = { version = "1.0", features = ["raw_value"] } time = { version = "0.3", features = ["serde-human-readable"] } tokio = { package = "tokio", version = "1.25.0", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] } tokio-stream = "0.1" diff --git a/axum/src/response/sse.rs b/axum/src/response/sse.rs index e77b8c78..b414f057 100644 --- a/axum/src/response/sse.rs +++ b/axum/src/response/sse.rs @@ -208,12 +208,29 @@ impl Event { where T: serde::Serialize, { + struct IgnoreNewLines<'a>(bytes::buf::Writer<&'a mut BytesMut>); + impl std::io::Write for IgnoreNewLines<'_> { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let mut last_split = 0; + for delimiter in memchr::memchr2_iter(b'\n', b'\r', buf) { + self.0.write_all(&buf[last_split..delimiter])?; + last_split = delimiter + 1; + } + self.0.write_all(&buf[last_split..])?; + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.0.flush() + } + } if self.flags.contains(EventFlags::HAS_DATA) { panic!("Called `EventBuilder::json_data` multiple times"); } self.buffer.extend_from_slice(b"data: "); - serde_json::to_writer((&mut self.buffer).writer(), &data).map_err(axum_core::Error::new)?; + serde_json::to_writer(IgnoreNewLines((&mut self.buffer).writer()), &data) + .map_err(axum_core::Error::new)?; self.buffer.put_u8(b'\n'); self.flags.insert(EventFlags::HAS_DATA); @@ -515,6 +532,7 @@ mod tests { use super::*; use crate::{routing::get, test_helpers::*, Router}; use futures_util::stream; + use serde_json::value::RawValue; use std::{collections::HashMap, convert::Infallible}; use tokio_stream::StreamExt as _; @@ -527,6 +545,18 @@ mod tests { assert_eq!(&*leading_space.finalize(), b"data: foobar\n\n"); } + #[test] + fn valid_json_raw_value_chars_stripped() { + let json_string = "{\r\"foo\": \n\r\r \"bar\\n\"\n}"; + let json_raw_value_event = Event::default() + .json_data(serde_json::from_str::<&RawValue>(json_string).unwrap()) + .unwrap(); + assert_eq!( + &*json_raw_value_event.finalize(), + format!("data: {}\n\n", json_string.replace(['\n', '\r'], "")).as_bytes() + ); + } + #[crate::test] async fn basic() { let app = Router::new().route( From b30cdcfbead0be1048258cbcaf13663e5fcf9f03 Mon Sep 17 00:00:00 2001 From: Benjamin Sparks Date: Thu, 17 Oct 2024 16:47:51 +0200 Subject: [PATCH 19/22] Remove unneeded macro usage (#2995) Co-authored-by: Benjamin Sparks --- axum/src/handler/mod.rs | 2 -- axum/src/middleware/from_fn.rs | 2 -- axum/src/middleware/map_response.rs | 2 -- axum/src/routing/route.rs | 2 -- examples/serve-with-hyper/src/main.rs | 2 -- examples/unix-domain-socket/src/main.rs | 2 -- 6 files changed, 12 deletions(-) diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index 04c66c4a..783e02e3 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -328,8 +328,6 @@ where ) -> _, > = svc.oneshot(req).map(|result| match result { Ok(res) => res.into_response(), - - #[allow(unreachable_patterns)] Err(err) => match err {}, }); diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index f0f47611..3ed7a959 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -341,8 +341,6 @@ impl Next { pub async fn run(mut self, req: Request) -> Response { match self.inner.call(req).await { Ok(res) => res, - - #[allow(unreachable_patterns)] Err(err) => match err {}, } } diff --git a/axum/src/middleware/map_response.rs b/axum/src/middleware/map_response.rs index e4c1c397..2510cdc2 100644 --- a/axum/src/middleware/map_response.rs +++ b/axum/src/middleware/map_response.rs @@ -278,8 +278,6 @@ macro_rules! impl_service { Ok(res) => { f($($ty,)* res).await.into_response() } - - #[allow(unreachable_patterns)] Err(err) => match err {} } }); diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index 3067e675..8fc9a1d7 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -250,8 +250,6 @@ impl Future for InfallibleRouteFuture { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match futures_util::ready!(self.project().future.poll(cx)) { Ok(response) => Poll::Ready(response), - - #[allow(unreachable_patterns)] Err(err) => match err {}, } } diff --git a/examples/serve-with-hyper/src/main.rs b/examples/serve-with-hyper/src/main.rs index 9da67fc1..8bad2acd 100644 --- a/examples/serve-with-hyper/src/main.rs +++ b/examples/serve-with-hyper/src/main.rs @@ -11,8 +11,6 @@ //! //! [hyper-util]: https://crates.io/crates/hyper-util -#![allow(unreachable_patterns)] - use std::convert::Infallible; use std::net::SocketAddr; diff --git a/examples/unix-domain-socket/src/main.rs b/examples/unix-domain-socket/src/main.rs index fbb4c3b0..697f31a5 100644 --- a/examples/unix-domain-socket/src/main.rs +++ b/examples/unix-domain-socket/src/main.rs @@ -3,8 +3,6 @@ //! ```not_rust //! cargo run -p example-unix-domain-socket //! ``` -#![allow(unreachable_patterns)] - #[cfg(unix)] #[tokio::main] async fn main() { From 8bc326cc3deefda4604e9cc780c24bcd5613723e Mon Sep 17 00:00:00 2001 From: Leon Lux Date: Mon, 21 Oct 2024 21:47:20 +0200 Subject: [PATCH 20/22] Improve docs regarding state and extensions (#2991) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: David Mládek --- axum/src/docs/routing/with_state.md | 10 ++-------- axum/src/extract/state.rs | 4 ++++ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/axum/src/docs/routing/with_state.md b/axum/src/docs/routing/with_state.md index 973a87e0..197741cf 100644 --- a/axum/src/docs/routing/with_state.md +++ b/axum/src/docs/routing/with_state.md @@ -1,4 +1,5 @@ -Provide the state for the router. +Provide the state for the router. State passed to this method is global and will be used +for all requests this router receives. That means it is not suitable for holding state derived from a request, such as authorization data extracted in a middleware. Use [`Extension`] instead for such data. ```rust use axum::{Router, routing::get, extract::State}; @@ -94,13 +95,6 @@ axum::serve(listener, routes).await.unwrap(); # }; ``` -# State is global within the router - -The state passed to this method will be used for all requests this router -receives. That means it is not suitable for holding state derived from a -request, such as authorization data extracted in a middleware. Use [`Extension`] -instead for such data. - # What `S` in `Router` means `Router` means a router that is _missing_ a state of type `S` to be able to diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs index fb401c00..82e1e6e9 100644 --- a/axum/src/extract/state.rs +++ b/axum/src/extract/state.rs @@ -11,7 +11,11 @@ use std::{ /// See ["Accessing state in middleware"][state-from-middleware] for how to /// access state in middleware. /// +/// State is global and used in every request a router with state receives. +/// For accessing data derived from requests, such as authorization data, see [`Extension`]. +/// /// [state-from-middleware]: crate::middleware#accessing-state-in-middleware +/// [`Extension`]: crate::Extension /// /// # With `Router` /// From da4580247ab71ff97d3604344cff1b19f40db416 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Thu, 14 Nov 2024 21:12:16 +0100 Subject: [PATCH 21/22] Some documentation fixes (#3027) --- axum-core/Cargo.toml | 1 - axum-core/src/extract/rejection.rs | 2 +- axum-extra/Cargo.toml | 1 - axum-extra/src/extract/multipart.rs | 2 +- axum-extra/src/lib.rs | 2 +- axum-macros/Cargo.toml | 1 - axum/Cargo.toml | 10 +++++++--- 7 files changed, 10 insertions(+), 9 deletions(-) diff --git a/axum-core/Cargo.toml b/axum-core/Cargo.toml index f7c31668..7ba2e5f2 100644 --- a/axum-core/Cargo.toml +++ b/axum-core/Cargo.toml @@ -57,4 +57,3 @@ allowed = [ [package.metadata.docs.rs] all-features = true -rustdoc-args = ["--cfg", "docsrs"] diff --git a/axum-core/src/extract/rejection.rs b/axum-core/src/extract/rejection.rs index 34b8115b..c5c3b1db 100644 --- a/axum-core/src/extract/rejection.rs +++ b/axum-core/src/extract/rejection.rs @@ -42,7 +42,7 @@ define_rejection! { #[body = "Failed to buffer the request body"] /// Encountered some other error when buffering the body. /// - /// This can _only_ happen when you're using [`tower_http::limit::RequestBodyLimitLayer`] or + /// This can _only_ happen when you're using [`tower_http::limit::RequestBodyLimitLayer`] or /// otherwise wrapping request bodies in [`http_body_util::Limited`]. pub struct LengthLimitError(Error); } diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index 4b02d023..872febdf 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -84,7 +84,6 @@ tower-http = { version = "0.6.0", features = ["map-response-body", "timeout"] } [package.metadata.docs.rs] all-features = true -rustdoc-args = ["--cfg", "docsrs"] [package.metadata.cargo-public-api-crates] allowed = [ diff --git a/axum-extra/src/extract/multipart.rs b/axum-extra/src/extract/multipart.rs index 70f04866..6101e64f 100644 --- a/axum-extra/src/extract/multipart.rs +++ b/axum-extra/src/extract/multipart.rs @@ -75,7 +75,7 @@ use std::{ /// to keep `Field`s around from previous loop iterations. That will minimize the risk of runtime /// errors. /// -/// # Differences between this and `axum::extract::Multipart` +/// # Differences between this and `axum::extract::Multipart` /// /// `axum::extract::Multipart` uses lifetimes to enforce field exclusivity at compile time, however /// that leads to significant usability issues such as `Field` not being `'static`. diff --git a/axum-extra/src/lib.rs b/axum-extra/src/lib.rs index 02dd6e69..9aa3ddf7 100644 --- a/axum-extra/src/lib.rs +++ b/axum-extra/src/lib.rs @@ -23,7 +23,7 @@ //! `query` | Enables the `Query` extractor | No //! `tracing` | Log rejections from built-in extractors | Yes //! `typed-routing` | Enables the `TypedPath` routing utilities | No -//! `typed-header` | Enables the `TypedHeader` extractor and response | No +//! `typed-header` | Enables the `TypedHeader` extractor and response | No //! //! [`axum`]: https://crates.io/crates/axum diff --git a/axum-macros/Cargo.toml b/axum-macros/Cargo.toml index d24af7e5..8dcacfed 100644 --- a/axum-macros/Cargo.toml +++ b/axum-macros/Cargo.toml @@ -43,4 +43,3 @@ allowed = [] [package.metadata.docs.rs] all-features = true -rustdoc-args = ["--cfg", "docsrs"] diff --git a/axum/Cargo.toml b/axum/Cargo.toml index 6b15c455..5f2e0700 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -37,8 +37,13 @@ tower-log = ["tower/log"] tracing = ["dep:tracing", "axum-core/tracing"] ws = ["dep:hyper", "tokio", "dep:tokio-tungstenite", "dep:sha1", "dep:base64"] -# Required for intra-doc links to resolve correctly -__private_docs = ["tower/full", "dep:tower-http"] +__private_docs = [ + # We re-export some docs from axum-core via #[doc(inline)], + # but they need the same sort of treatment as below to be complete + "axum-core/__private_docs", + # Enables upstream things linked to in docs + "tower/full", "dep:tower-http", +] [dependencies] async-trait = "0.1.67" @@ -128,7 +133,6 @@ uuid = { version = "1.0", features = ["serde", "v4"] } [package.metadata.docs.rs] all-features = true -rustdoc-args = ["--cfg", "docsrs"] [dev-dependencies.tower] package = "tower" From feee742ca14bcf68b826f304e39f92db5f7d76e5 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Thu, 14 Nov 2024 23:13:41 +0100 Subject: [PATCH 22/22] Bump versions --- axum-extra/CHANGELOG.md | 2 +- axum-extra/Cargo.toml | 4 ++-- axum/CHANGELOG.md | 2 +- axum/Cargo.toml | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index 25b84c0d..34abb208 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog], and this project adheres to [Semantic Versioning]. -# Unreleased +# 0.9.5 - **added:** Add `RouterExt::typed_connect` ([#2961]) - **added:** Add `json!` for easy construction of JSON responses ([#2962]) diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index 872febdf..266ed3b8 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -9,7 +9,7 @@ license = "MIT" name = "axum-extra" readme = "README.md" repository = "https://github.com/tokio-rs/axum" -version = "0.9.4" +version = "0.9.5" [features] default = ["tracing", "multipart"] @@ -39,7 +39,7 @@ typed-header = ["dep:headers"] typed-routing = ["dep:axum-macros", "dep:percent-encoding", "dep:serde_html_form", "dep:form_urlencoded"] [dependencies] -axum = { path = "../axum", version = "0.7.7", default-features = false, features = ["original-uri"] } +axum = { path = "../axum", version = "0.7.8", default-features = false, features = ["original-uri"] } axum-core = { path = "../axum-core", version = "0.4.5" } bytes = "1.1.0" futures-util = { version = "0.3", default-features = false, features = ["alloc"] } diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 69a42100..6eb80d6d 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -# Unreleased +# 0.7.8 - **fixed:** Skip SSE incompatible chars of `serde_json::RawValue` in `Event::json_data` ([#2992]) - **added:** Add `method_not_allowed_fallback` to set a fallback when a path matches but there is no handler for the given HTTP method ([#2903]) diff --git a/axum/Cargo.toml b/axum/Cargo.toml index 5f2e0700..c47a1ead 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "axum" -version = "0.7.7" +version = "0.7.8" categories = ["asynchronous", "network-programming", "web-programming::http-server"] description = "Web framework that focuses on ergonomics and modularity" edition = "2021"