From 23f20ea1f3a8e6585d2965f3cfdc1e91768d24a8 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 26 Apr 2022 10:16:38 +0200 Subject: [PATCH] Replace `ByteStr` with `Arc` (#971) --- axum/src/extract/path/de.rs | 21 +++++++-------- axum/src/extract/path/mod.rs | 2 +- axum/src/routing/url_params.rs | 13 ++++----- axum/src/util.rs | 48 +++++----------------------------- 4 files changed, 24 insertions(+), 60 deletions(-) diff --git a/axum/src/extract/path/de.rs b/axum/src/extract/path/de.rs index 03ad2000..62236c6e 100644 --- a/axum/src/extract/path/de.rs +++ b/axum/src/extract/path/de.rs @@ -1,10 +1,10 @@ use super::{ErrorKind, PathDeserializationError}; -use crate::util::{ByteStr, PercentDecodedByteStr}; +use crate::util::PercentDecodedStr; use serde::{ de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor}, forward_to_deserialize_any, Deserializer, }; -use std::any::type_name; +use std::{any::type_name, sync::Arc}; macro_rules! unsupported_type { ($trait_fn:ident, $name:literal) => { @@ -43,12 +43,12 @@ macro_rules! parse_single_value { } pub(crate) struct PathDeserializer<'de> { - url_params: &'de [(ByteStr, PercentDecodedByteStr)], + url_params: &'de [(Arc, PercentDecodedStr)], } impl<'de> PathDeserializer<'de> { #[inline] - pub(crate) fn new(url_params: &'de [(ByteStr, PercentDecodedByteStr)]) -> Self { + pub(crate) fn new(url_params: &'de [(Arc, PercentDecodedStr)]) -> Self { PathDeserializer { url_params } } } @@ -210,7 +210,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { } struct MapDeserializer<'de> { - params: &'de [(ByteStr, PercentDecodedByteStr)], + params: &'de [(Arc, PercentDecodedStr)], key: Option, value: Option<&'de str>, } @@ -293,7 +293,7 @@ macro_rules! parse_value { if let Some(key) = self.key.take() { let kind = match key { KeyOrIdx::Key(key) => ErrorKind::ParseErrorAtKey { - key: key.as_str().to_owned(), + key: key.to_string(), value: self.value.to_owned(), expected_type: $ty, }, @@ -514,7 +514,7 @@ impl<'de> VariantAccess<'de> for UnitVariant { } struct SeqDeserializer<'de> { - params: &'de [(ByteStr, PercentDecodedByteStr)], + params: &'de [(Arc, PercentDecodedStr)], idx: usize, } @@ -542,14 +542,13 @@ impl<'de> SeqAccess<'de> for SeqDeserializer<'de> { #[derive(Clone)] enum KeyOrIdx { - Key(ByteStr), + Key(Arc), Idx(usize), } #[cfg(test)] mod tests { use super::*; - use crate::util::ByteStr; use serde::Deserialize; use std::collections::HashMap; @@ -568,7 +567,7 @@ mod tests { a: i32, } - fn create_url_params(values: I) -> Vec<(ByteStr, PercentDecodedByteStr)> + fn create_url_params(values: I) -> Vec<(Arc, PercentDecodedStr)> where I: IntoIterator, K: AsRef, @@ -576,7 +575,7 @@ mod tests { { values .into_iter() - .map(|(k, v)| (ByteStr::new(k), PercentDecodedByteStr::new(v).unwrap())) + .map(|(k, v)| (Arc::from(k.as_ref()), PercentDecodedStr::new(v).unwrap())) .collect() } diff --git a/axum/src/extract/path/mod.rs b/axum/src/extract/path/mod.rs index 5c0d331d..39fd3bb1 100644 --- a/axum/src/extract/path/mod.rs +++ b/axum/src/extract/path/mod.rs @@ -166,7 +166,7 @@ where Some(UrlParams::InvalidUtf8InPathParam { key }) => { let err = PathDeserializationError { kind: ErrorKind::InvalidUtf8InPathParam { - key: key.as_str().to_owned(), + key: key.to_string(), }, }; let err = FailedToDeserializePathParams(err); diff --git a/axum/src/routing/url_params.rs b/axum/src/routing/url_params.rs index c24b08fc..c9f05bb6 100644 --- a/axum/src/routing/url_params.rs +++ b/axum/src/routing/url_params.rs @@ -1,10 +1,11 @@ -use crate::util::{ByteStr, PercentDecodedByteStr}; +use crate::util::PercentDecodedStr; use http::Extensions; use matchit::Params; +use std::sync::Arc; pub(crate) enum UrlParams { - Params(Vec<(ByteStr, PercentDecodedByteStr)>), - InvalidUtf8InPathParam { key: ByteStr }, + Params(Vec<(Arc, PercentDecodedStr)>), + InvalidUtf8InPathParam { key: Arc }, } pub(super) fn insert_url_params(extensions: &mut Extensions, params: Params) { @@ -19,10 +20,10 @@ pub(super) fn insert_url_params(extensions: &mut Extensions, params: Params) { .iter() .filter(|(key, _)| !key.starts_with(super::NEST_TAIL_PARAM)) .map(|(k, v)| { - if let Some(decoded) = PercentDecodedByteStr::new(v) { - Ok((ByteStr::new(k), decoded)) + if let Some(decoded) = PercentDecodedStr::new(v) { + Ok((Arc::from(k), decoded)) } else { - Err(ByteStr::new(k)) + Err(Arc::from(k)) } }) .collect::, _>>(); diff --git a/axum/src/util.rs b/axum/src/util.rs index 9cb96c4a..b5805e07 100644 --- a/axum/src/util.rs +++ b/axum/src/util.rs @@ -1,46 +1,10 @@ -use crate::body::Bytes; use pin_project_lite::pin_project; -use std::fmt; -use std::ops::Deref; - -/// A string like type backed by `Bytes` making it cheap to clone. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub(crate) struct ByteStr(Bytes); - -impl Deref for ByteStr { - type Target = str; - - #[inline] - fn deref(&self) -> &Self::Target { - self.as_str() - } -} - -impl ByteStr { - pub(crate) fn new(s: S) -> Self - where - S: AsRef, - { - Self(Bytes::copy_from_slice(s.as_ref().as_bytes())) - } - - pub(crate) fn as_str(&self) -> &str { - // `ByteStr` can only be constructed from strings which are always valid - // utf-8 so this wont panic. - std::str::from_utf8(&self.0).unwrap() - } -} - -impl fmt::Display for ByteStr { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.as_str()) - } -} +use std::{ops::Deref, sync::Arc}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub(crate) struct PercentDecodedByteStr(ByteStr); +pub(crate) struct PercentDecodedStr(Arc); -impl PercentDecodedByteStr { +impl PercentDecodedStr { pub(crate) fn new(s: S) -> Option where S: AsRef, @@ -48,15 +12,15 @@ impl PercentDecodedByteStr { percent_encoding::percent_decode(s.as_ref().as_bytes()) .decode_utf8() .ok() - .map(|decoded| Self(ByteStr::new(decoded))) + .map(|decoded| Self(decoded.as_ref().into())) } pub(crate) fn as_str(&self) -> &str { - self.0.as_str() + &*self.0 } } -impl Deref for PercentDecodedByteStr { +impl Deref for PercentDecodedStr { type Target = str; #[inline]