Refactor storing URL params in extensions (#833)

This commit is contained in:
David Pedersen 2022-03-06 12:41:16 +01:00 committed by GitHub
parent 79b94b9bd6
commit a438e6b106
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 52 additions and 58 deletions

View file

@ -5,14 +5,13 @@ mod de;
use crate::{
extract::{rejection::*, FromRequest, RequestParts},
routing::{InvalidUtf8InPathParam, UrlParams},
routing::url_params::UrlParams,
};
use async_trait::async_trait;
use axum_core::response::{IntoResponse, Response};
use http::StatusCode;
use serde::de::DeserializeOwned;
use std::{
borrow::Cow,
fmt,
ops::{Deref, DerefMut},
};
@ -162,9 +161,9 @@ where
type Rejection = PathRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let params = match req.extensions_mut().get::<Option<UrlParams>>() {
Some(Some(UrlParams(Ok(params)))) => Cow::Borrowed(params),
Some(Some(UrlParams(Err(InvalidUtf8InPathParam { key })))) => {
let params = match req.extensions_mut().get::<UrlParams>() {
Some(UrlParams::Params(params)) => params,
Some(UrlParams::InvalidUtf8InPathParam { key }) => {
let err = PathDeserializationError {
kind: ErrorKind::InvalidUtf8InPathParam {
key: key.as_str().to_owned(),
@ -173,7 +172,6 @@ where
let err = FailedToDeserializePathParams(err);
return Err(err.into());
}
Some(None) => Cow::Owned(Vec::new()),
None => {
return Err(MissingPathParams.into());
}

View file

@ -6,7 +6,7 @@ use crate::{
extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo},
response::{IntoResponse, Redirect, Response},
routing::strip_prefix::StripPrefix,
util::{try_downcast, ByteStr, PercentDecodedByteStr},
util::try_downcast,
BoxError,
};
use http::{Request, Uri};
@ -31,6 +31,7 @@ mod method_routing;
mod not_found;
mod route;
mod strip_prefix;
pub(crate) mod url_params;
#[cfg(test)]
mod tests;
@ -446,14 +447,7 @@ where
panic!("should always have a matched path for a route id");
}
let params = match_
.params
.iter()
.filter(|(key, _)| !key.starts_with(NEST_TAIL_PARAM))
.map(|(key, value)| (key.to_owned(), value.to_owned()))
.collect::<Vec<_>>();
insert_url_params(&mut req, params);
url_params::insert_url_params(req.extensions_mut(), match_.params);
let mut route = self
.routes
@ -557,49 +551,6 @@ fn with_path(uri: &Uri, new_path: &str) -> Uri {
Uri::from_parts(parts).unwrap()
}
// we store the potential error here such that users can handle invalid path
// params using `Result<Path<T>, _>`. That wouldn't be possible if we
// returned an error immediately when decoding the param
pub(crate) struct UrlParams(
pub(crate) Result<Vec<(ByteStr, PercentDecodedByteStr)>, InvalidUtf8InPathParam>,
);
fn insert_url_params<B>(req: &mut Request<B>, params: Vec<(String, String)>) {
let params = params
.into_iter()
.map(|(k, v)| {
if let Some(decoded) = PercentDecodedByteStr::new(v) {
Ok((ByteStr::new(k), decoded))
} else {
Err(InvalidUtf8InPathParam {
key: ByteStr::new(k),
})
}
})
.collect::<Result<Vec<_>, _>>();
if let Some(current) = req.extensions_mut().get_mut::<Option<UrlParams>>() {
match params {
Ok(params) => {
let mut current = current.take().unwrap();
if let Ok(current) = &mut current.0 {
current.extend(params);
}
req.extensions_mut().insert(Some(current));
}
Err(err) => {
req.extensions_mut().insert(Some(UrlParams(Err(err))));
}
}
} else {
req.extensions_mut().insert(Some(UrlParams(params)));
}
}
pub(crate) struct InvalidUtf8InPathParam {
pub(crate) key: ByteStr,
}
/// Wrapper around `matchit::Node` that supports merging two `Node`s.
#[derive(Clone, Default)]
struct Node {

View file

@ -0,0 +1,45 @@
use crate::util::{ByteStr, PercentDecodedByteStr};
use http::Extensions;
use matchit::Params;
pub(crate) enum UrlParams {
Params(Vec<(ByteStr, PercentDecodedByteStr)>),
InvalidUtf8InPathParam { key: ByteStr },
}
pub(super) fn insert_url_params(extensions: &mut Extensions, params: Params) {
let current_params = extensions.get_mut();
if let Some(UrlParams::InvalidUtf8InPathParam { .. }) = current_params {
// nothing to do here since an error was stored earlier
return;
}
let params = params
.iter()
.filter(|(key, _)| !key.starts_with(super::NEST_TAIL_PARAM))
.map(|(key, value)| (key.to_owned(), value.to_owned()))
.map(|(k, v)| {
if let Some(decoded) = PercentDecodedByteStr::new(v) {
Ok((ByteStr::new(k), decoded))
} else {
Err(ByteStr::new(k))
}
})
.collect::<Result<Vec<_>, _>>();
match (current_params, params) {
(Some(UrlParams::InvalidUtf8InPathParam { .. }), _) => {
unreachable!("we check for this state earlier in this method")
}
(_, Err(invalid_key)) => {
extensions.insert(UrlParams::InvalidUtf8InPathParam { key: invalid_key });
}
(Some(UrlParams::Params(current)), Ok(params)) => {
current.extend(params);
}
(None, Ok(params)) => {
extensions.insert(UrlParams::Params(params));
}
}
}