mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-18 23:23:30 +01:00
Refactor storing URL params in extensions (#833)
This commit is contained in:
parent
79b94b9bd6
commit
a438e6b106
3 changed files with 52 additions and 58 deletions
|
@ -5,14 +5,13 @@ mod de;
|
|||
|
||||
use crate::{
|
||||
extract::{rejection::*, FromRequest, RequestParts},
|
||||
routing::{InvalidUtf8InPathParam, UrlParams},
|
||||
routing::url_params::UrlParams,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use axum_core::response::{IntoResponse, Response};
|
||||
use http::StatusCode;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
fmt,
|
||||
ops::{Deref, DerefMut},
|
||||
};
|
||||
|
@ -162,9 +161,9 @@ where
|
|||
type Rejection = PathRejection;
|
||||
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
let params = match req.extensions_mut().get::<Option<UrlParams>>() {
|
||||
Some(Some(UrlParams(Ok(params)))) => Cow::Borrowed(params),
|
||||
Some(Some(UrlParams(Err(InvalidUtf8InPathParam { key })))) => {
|
||||
let params = match req.extensions_mut().get::<UrlParams>() {
|
||||
Some(UrlParams::Params(params)) => params,
|
||||
Some(UrlParams::InvalidUtf8InPathParam { key }) => {
|
||||
let err = PathDeserializationError {
|
||||
kind: ErrorKind::InvalidUtf8InPathParam {
|
||||
key: key.as_str().to_owned(),
|
||||
|
@ -173,7 +172,6 @@ where
|
|||
let err = FailedToDeserializePathParams(err);
|
||||
return Err(err.into());
|
||||
}
|
||||
Some(None) => Cow::Owned(Vec::new()),
|
||||
None => {
|
||||
return Err(MissingPathParams.into());
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ use crate::{
|
|||
extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo},
|
||||
response::{IntoResponse, Redirect, Response},
|
||||
routing::strip_prefix::StripPrefix,
|
||||
util::{try_downcast, ByteStr, PercentDecodedByteStr},
|
||||
util::try_downcast,
|
||||
BoxError,
|
||||
};
|
||||
use http::{Request, Uri};
|
||||
|
@ -31,6 +31,7 @@ mod method_routing;
|
|||
mod not_found;
|
||||
mod route;
|
||||
mod strip_prefix;
|
||||
pub(crate) mod url_params;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
@ -446,14 +447,7 @@ where
|
|||
panic!("should always have a matched path for a route id");
|
||||
}
|
||||
|
||||
let params = match_
|
||||
.params
|
||||
.iter()
|
||||
.filter(|(key, _)| !key.starts_with(NEST_TAIL_PARAM))
|
||||
.map(|(key, value)| (key.to_owned(), value.to_owned()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
insert_url_params(&mut req, params);
|
||||
url_params::insert_url_params(req.extensions_mut(), match_.params);
|
||||
|
||||
let mut route = self
|
||||
.routes
|
||||
|
@ -557,49 +551,6 @@ fn with_path(uri: &Uri, new_path: &str) -> Uri {
|
|||
Uri::from_parts(parts).unwrap()
|
||||
}
|
||||
|
||||
// we store the potential error here such that users can handle invalid path
|
||||
// params using `Result<Path<T>, _>`. That wouldn't be possible if we
|
||||
// returned an error immediately when decoding the param
|
||||
pub(crate) struct UrlParams(
|
||||
pub(crate) Result<Vec<(ByteStr, PercentDecodedByteStr)>, InvalidUtf8InPathParam>,
|
||||
);
|
||||
|
||||
fn insert_url_params<B>(req: &mut Request<B>, params: Vec<(String, String)>) {
|
||||
let params = params
|
||||
.into_iter()
|
||||
.map(|(k, v)| {
|
||||
if let Some(decoded) = PercentDecodedByteStr::new(v) {
|
||||
Ok((ByteStr::new(k), decoded))
|
||||
} else {
|
||||
Err(InvalidUtf8InPathParam {
|
||||
key: ByteStr::new(k),
|
||||
})
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>();
|
||||
|
||||
if let Some(current) = req.extensions_mut().get_mut::<Option<UrlParams>>() {
|
||||
match params {
|
||||
Ok(params) => {
|
||||
let mut current = current.take().unwrap();
|
||||
if let Ok(current) = &mut current.0 {
|
||||
current.extend(params);
|
||||
}
|
||||
req.extensions_mut().insert(Some(current));
|
||||
}
|
||||
Err(err) => {
|
||||
req.extensions_mut().insert(Some(UrlParams(Err(err))));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
req.extensions_mut().insert(Some(UrlParams(params)));
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct InvalidUtf8InPathParam {
|
||||
pub(crate) key: ByteStr,
|
||||
}
|
||||
|
||||
/// Wrapper around `matchit::Node` that supports merging two `Node`s.
|
||||
#[derive(Clone, Default)]
|
||||
struct Node {
|
||||
|
|
45
axum/src/routing/url_params.rs
Normal file
45
axum/src/routing/url_params.rs
Normal file
|
@ -0,0 +1,45 @@
|
|||
use crate::util::{ByteStr, PercentDecodedByteStr};
|
||||
use http::Extensions;
|
||||
use matchit::Params;
|
||||
|
||||
pub(crate) enum UrlParams {
|
||||
Params(Vec<(ByteStr, PercentDecodedByteStr)>),
|
||||
InvalidUtf8InPathParam { key: ByteStr },
|
||||
}
|
||||
|
||||
pub(super) fn insert_url_params(extensions: &mut Extensions, params: Params) {
|
||||
let current_params = extensions.get_mut();
|
||||
|
||||
if let Some(UrlParams::InvalidUtf8InPathParam { .. }) = current_params {
|
||||
// nothing to do here since an error was stored earlier
|
||||
return;
|
||||
}
|
||||
|
||||
let params = params
|
||||
.iter()
|
||||
.filter(|(key, _)| !key.starts_with(super::NEST_TAIL_PARAM))
|
||||
.map(|(key, value)| (key.to_owned(), value.to_owned()))
|
||||
.map(|(k, v)| {
|
||||
if let Some(decoded) = PercentDecodedByteStr::new(v) {
|
||||
Ok((ByteStr::new(k), decoded))
|
||||
} else {
|
||||
Err(ByteStr::new(k))
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>();
|
||||
|
||||
match (current_params, params) {
|
||||
(Some(UrlParams::InvalidUtf8InPathParam { .. }), _) => {
|
||||
unreachable!("we check for this state earlier in this method")
|
||||
}
|
||||
(_, Err(invalid_key)) => {
|
||||
extensions.insert(UrlParams::InvalidUtf8InPathParam { key: invalid_key });
|
||||
}
|
||||
(Some(UrlParams::Params(current)), Ok(params)) => {
|
||||
current.extend(params);
|
||||
}
|
||||
(None, Ok(params)) => {
|
||||
extensions.insert(UrlParams::Params(params));
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue