diff --git a/axum/src/extract/path/mod.rs b/axum/src/extract/path/mod.rs index 177da224..5c0d331d 100644 --- a/axum/src/extract/path/mod.rs +++ b/axum/src/extract/path/mod.rs @@ -5,14 +5,13 @@ mod de; use crate::{ extract::{rejection::*, FromRequest, RequestParts}, - routing::{InvalidUtf8InPathParam, UrlParams}, + routing::url_params::UrlParams, }; use async_trait::async_trait; use axum_core::response::{IntoResponse, Response}; use http::StatusCode; use serde::de::DeserializeOwned; use std::{ - borrow::Cow, fmt, ops::{Deref, DerefMut}, }; @@ -162,9 +161,9 @@ where type Rejection = PathRejection; async fn from_request(req: &mut RequestParts) -> Result { - let params = match req.extensions_mut().get::>() { - Some(Some(UrlParams(Ok(params)))) => Cow::Borrowed(params), - Some(Some(UrlParams(Err(InvalidUtf8InPathParam { key })))) => { + let params = match req.extensions_mut().get::() { + Some(UrlParams::Params(params)) => params, + Some(UrlParams::InvalidUtf8InPathParam { key }) => { let err = PathDeserializationError { kind: ErrorKind::InvalidUtf8InPathParam { key: key.as_str().to_owned(), @@ -173,7 +172,6 @@ where let err = FailedToDeserializePathParams(err); return Err(err.into()); } - Some(None) => Cow::Owned(Vec::new()), None => { return Err(MissingPathParams.into()); } diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 59c2d6c6..be96a644 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -6,7 +6,7 @@ use crate::{ extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo}, response::{IntoResponse, Redirect, Response}, routing::strip_prefix::StripPrefix, - util::{try_downcast, ByteStr, PercentDecodedByteStr}, + util::try_downcast, BoxError, }; use http::{Request, Uri}; @@ -31,6 +31,7 @@ mod method_routing; mod not_found; mod route; mod strip_prefix; +pub(crate) mod url_params; #[cfg(test)] mod tests; @@ -446,14 +447,7 @@ where panic!("should always have a matched path for a route id"); } - let params = match_ - .params - .iter() - .filter(|(key, _)| !key.starts_with(NEST_TAIL_PARAM)) - .map(|(key, value)| (key.to_owned(), value.to_owned())) - .collect::>(); - - insert_url_params(&mut req, params); + url_params::insert_url_params(req.extensions_mut(), match_.params); let mut route = self .routes @@ -557,49 +551,6 @@ fn with_path(uri: &Uri, new_path: &str) -> Uri { Uri::from_parts(parts).unwrap() } -// we store the potential error here such that users can handle invalid path -// params using `Result, _>`. That wouldn't be possible if we -// returned an error immediately when decoding the param -pub(crate) struct UrlParams( - pub(crate) Result, InvalidUtf8InPathParam>, -); - -fn insert_url_params(req: &mut Request, params: Vec<(String, String)>) { - let params = params - .into_iter() - .map(|(k, v)| { - if let Some(decoded) = PercentDecodedByteStr::new(v) { - Ok((ByteStr::new(k), decoded)) - } else { - Err(InvalidUtf8InPathParam { - key: ByteStr::new(k), - }) - } - }) - .collect::, _>>(); - - if let Some(current) = req.extensions_mut().get_mut::>() { - match params { - Ok(params) => { - let mut current = current.take().unwrap(); - if let Ok(current) = &mut current.0 { - current.extend(params); - } - req.extensions_mut().insert(Some(current)); - } - Err(err) => { - req.extensions_mut().insert(Some(UrlParams(Err(err)))); - } - } - } else { - req.extensions_mut().insert(Some(UrlParams(params))); - } -} - -pub(crate) struct InvalidUtf8InPathParam { - pub(crate) key: ByteStr, -} - /// Wrapper around `matchit::Node` that supports merging two `Node`s. #[derive(Clone, Default)] struct Node { diff --git a/axum/src/routing/url_params.rs b/axum/src/routing/url_params.rs new file mode 100644 index 00000000..71f3ccf2 --- /dev/null +++ b/axum/src/routing/url_params.rs @@ -0,0 +1,45 @@ +use crate::util::{ByteStr, PercentDecodedByteStr}; +use http::Extensions; +use matchit::Params; + +pub(crate) enum UrlParams { + Params(Vec<(ByteStr, PercentDecodedByteStr)>), + InvalidUtf8InPathParam { key: ByteStr }, +} + +pub(super) fn insert_url_params(extensions: &mut Extensions, params: Params) { + let current_params = extensions.get_mut(); + + if let Some(UrlParams::InvalidUtf8InPathParam { .. }) = current_params { + // nothing to do here since an error was stored earlier + return; + } + + let params = params + .iter() + .filter(|(key, _)| !key.starts_with(super::NEST_TAIL_PARAM)) + .map(|(key, value)| (key.to_owned(), value.to_owned())) + .map(|(k, v)| { + if let Some(decoded) = PercentDecodedByteStr::new(v) { + Ok((ByteStr::new(k), decoded)) + } else { + Err(ByteStr::new(k)) + } + }) + .collect::, _>>(); + + match (current_params, params) { + (Some(UrlParams::InvalidUtf8InPathParam { .. }), _) => { + unreachable!("we check for this state earlier in this method") + } + (_, Err(invalid_key)) => { + extensions.insert(UrlParams::InvalidUtf8InPathParam { key: invalid_key }); + } + (Some(UrlParams::Params(current)), Ok(params)) => { + current.extend(params); + } + (None, Ok(params)) => { + extensions.insert(UrlParams::Params(params)); + } + } +}