mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-05 18:27:07 +01:00
Refactor storing URL params in extensions (#833)
This commit is contained in:
parent
79b94b9bd6
commit
a438e6b106
3 changed files with 52 additions and 58 deletions
|
@ -5,14 +5,13 @@ mod de;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
extract::{rejection::*, FromRequest, RequestParts},
|
extract::{rejection::*, FromRequest, RequestParts},
|
||||||
routing::{InvalidUtf8InPathParam, UrlParams},
|
routing::url_params::UrlParams,
|
||||||
};
|
};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use axum_core::response::{IntoResponse, Response};
|
use axum_core::response::{IntoResponse, Response};
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use serde::de::DeserializeOwned;
|
use serde::de::DeserializeOwned;
|
||||||
use std::{
|
use std::{
|
||||||
borrow::Cow,
|
|
||||||
fmt,
|
fmt,
|
||||||
ops::{Deref, DerefMut},
|
ops::{Deref, DerefMut},
|
||||||
};
|
};
|
||||||
|
@ -162,9 +161,9 @@ where
|
||||||
type Rejection = PathRejection;
|
type Rejection = PathRejection;
|
||||||
|
|
||||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||||
let params = match req.extensions_mut().get::<Option<UrlParams>>() {
|
let params = match req.extensions_mut().get::<UrlParams>() {
|
||||||
Some(Some(UrlParams(Ok(params)))) => Cow::Borrowed(params),
|
Some(UrlParams::Params(params)) => params,
|
||||||
Some(Some(UrlParams(Err(InvalidUtf8InPathParam { key })))) => {
|
Some(UrlParams::InvalidUtf8InPathParam { key }) => {
|
||||||
let err = PathDeserializationError {
|
let err = PathDeserializationError {
|
||||||
kind: ErrorKind::InvalidUtf8InPathParam {
|
kind: ErrorKind::InvalidUtf8InPathParam {
|
||||||
key: key.as_str().to_owned(),
|
key: key.as_str().to_owned(),
|
||||||
|
@ -173,7 +172,6 @@ where
|
||||||
let err = FailedToDeserializePathParams(err);
|
let err = FailedToDeserializePathParams(err);
|
||||||
return Err(err.into());
|
return Err(err.into());
|
||||||
}
|
}
|
||||||
Some(None) => Cow::Owned(Vec::new()),
|
|
||||||
None => {
|
None => {
|
||||||
return Err(MissingPathParams.into());
|
return Err(MissingPathParams.into());
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@ use crate::{
|
||||||
extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo},
|
extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo},
|
||||||
response::{IntoResponse, Redirect, Response},
|
response::{IntoResponse, Redirect, Response},
|
||||||
routing::strip_prefix::StripPrefix,
|
routing::strip_prefix::StripPrefix,
|
||||||
util::{try_downcast, ByteStr, PercentDecodedByteStr},
|
util::try_downcast,
|
||||||
BoxError,
|
BoxError,
|
||||||
};
|
};
|
||||||
use http::{Request, Uri};
|
use http::{Request, Uri};
|
||||||
|
@ -31,6 +31,7 @@ mod method_routing;
|
||||||
mod not_found;
|
mod not_found;
|
||||||
mod route;
|
mod route;
|
||||||
mod strip_prefix;
|
mod strip_prefix;
|
||||||
|
pub(crate) mod url_params;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests;
|
mod tests;
|
||||||
|
@ -446,14 +447,7 @@ where
|
||||||
panic!("should always have a matched path for a route id");
|
panic!("should always have a matched path for a route id");
|
||||||
}
|
}
|
||||||
|
|
||||||
let params = match_
|
url_params::insert_url_params(req.extensions_mut(), match_.params);
|
||||||
.params
|
|
||||||
.iter()
|
|
||||||
.filter(|(key, _)| !key.starts_with(NEST_TAIL_PARAM))
|
|
||||||
.map(|(key, value)| (key.to_owned(), value.to_owned()))
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
insert_url_params(&mut req, params);
|
|
||||||
|
|
||||||
let mut route = self
|
let mut route = self
|
||||||
.routes
|
.routes
|
||||||
|
@ -557,49 +551,6 @@ fn with_path(uri: &Uri, new_path: &str) -> Uri {
|
||||||
Uri::from_parts(parts).unwrap()
|
Uri::from_parts(parts).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
// we store the potential error here such that users can handle invalid path
|
|
||||||
// params using `Result<Path<T>, _>`. That wouldn't be possible if we
|
|
||||||
// returned an error immediately when decoding the param
|
|
||||||
pub(crate) struct UrlParams(
|
|
||||||
pub(crate) Result<Vec<(ByteStr, PercentDecodedByteStr)>, InvalidUtf8InPathParam>,
|
|
||||||
);
|
|
||||||
|
|
||||||
fn insert_url_params<B>(req: &mut Request<B>, params: Vec<(String, String)>) {
|
|
||||||
let params = params
|
|
||||||
.into_iter()
|
|
||||||
.map(|(k, v)| {
|
|
||||||
if let Some(decoded) = PercentDecodedByteStr::new(v) {
|
|
||||||
Ok((ByteStr::new(k), decoded))
|
|
||||||
} else {
|
|
||||||
Err(InvalidUtf8InPathParam {
|
|
||||||
key: ByteStr::new(k),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<Result<Vec<_>, _>>();
|
|
||||||
|
|
||||||
if let Some(current) = req.extensions_mut().get_mut::<Option<UrlParams>>() {
|
|
||||||
match params {
|
|
||||||
Ok(params) => {
|
|
||||||
let mut current = current.take().unwrap();
|
|
||||||
if let Ok(current) = &mut current.0 {
|
|
||||||
current.extend(params);
|
|
||||||
}
|
|
||||||
req.extensions_mut().insert(Some(current));
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
req.extensions_mut().insert(Some(UrlParams(Err(err))));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
req.extensions_mut().insert(Some(UrlParams(params)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) struct InvalidUtf8InPathParam {
|
|
||||||
pub(crate) key: ByteStr,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Wrapper around `matchit::Node` that supports merging two `Node`s.
|
/// Wrapper around `matchit::Node` that supports merging two `Node`s.
|
||||||
#[derive(Clone, Default)]
|
#[derive(Clone, Default)]
|
||||||
struct Node {
|
struct Node {
|
||||||
|
|
45
axum/src/routing/url_params.rs
Normal file
45
axum/src/routing/url_params.rs
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
use crate::util::{ByteStr, PercentDecodedByteStr};
|
||||||
|
use http::Extensions;
|
||||||
|
use matchit::Params;
|
||||||
|
|
||||||
|
pub(crate) enum UrlParams {
|
||||||
|
Params(Vec<(ByteStr, PercentDecodedByteStr)>),
|
||||||
|
InvalidUtf8InPathParam { key: ByteStr },
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn insert_url_params(extensions: &mut Extensions, params: Params) {
|
||||||
|
let current_params = extensions.get_mut();
|
||||||
|
|
||||||
|
if let Some(UrlParams::InvalidUtf8InPathParam { .. }) = current_params {
|
||||||
|
// nothing to do here since an error was stored earlier
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let params = params
|
||||||
|
.iter()
|
||||||
|
.filter(|(key, _)| !key.starts_with(super::NEST_TAIL_PARAM))
|
||||||
|
.map(|(key, value)| (key.to_owned(), value.to_owned()))
|
||||||
|
.map(|(k, v)| {
|
||||||
|
if let Some(decoded) = PercentDecodedByteStr::new(v) {
|
||||||
|
Ok((ByteStr::new(k), decoded))
|
||||||
|
} else {
|
||||||
|
Err(ByteStr::new(k))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>();
|
||||||
|
|
||||||
|
match (current_params, params) {
|
||||||
|
(Some(UrlParams::InvalidUtf8InPathParam { .. }), _) => {
|
||||||
|
unreachable!("we check for this state earlier in this method")
|
||||||
|
}
|
||||||
|
(_, Err(invalid_key)) => {
|
||||||
|
extensions.insert(UrlParams::InvalidUtf8InPathParam { key: invalid_key });
|
||||||
|
}
|
||||||
|
(Some(UrlParams::Params(current)), Ok(params)) => {
|
||||||
|
current.extend(params);
|
||||||
|
}
|
||||||
|
(None, Ok(params)) => {
|
||||||
|
extensions.insert(UrlParams::Params(params));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue