From c41c9e0f78274703869a816296d16f8b67c2303f Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sun, 13 Jun 2021 13:06:33 +0200 Subject: [PATCH] Support extracting URL params multiple times (#15) Useful when building higher order extractors. --- src/extract/mod.rs | 26 +++++++++++++------------- src/extract/rejection.rs | 19 ------------------- src/lib.rs | 4 +++- src/routing.rs | 11 ++++++++--- src/tests.rs | 24 ++++++++++++++++++++++++ src/util.rs | 21 +++++++++++++++++++++ 6 files changed, 69 insertions(+), 36 deletions(-) create mode 100644 src/util.rs diff --git a/src/extract/mod.rs b/src/extract/mod.rs index 49e7c25b..1245f409 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -165,7 +165,7 @@ //! # }; //! ``` -use crate::{body::Body, response::IntoResponse}; +use crate::{body::Body, response::IntoResponse, util::ByteStr}; use async_trait::async_trait; use bytes::{Buf, Bytes}; use http::{header, HeaderMap, Method, Request, Uri, Version}; @@ -609,12 +609,12 @@ where /// Note that you can only have one URL params extractor per handler. If you /// have multiple it'll response with `500 Internal Server Error`. #[derive(Debug)] -pub struct UrlParamsMap(HashMap); +pub struct UrlParamsMap(HashMap); impl UrlParamsMap { /// Look up the value for a key. pub fn get(&self, key: &str) -> Option<&str> { - self.0.get(key).map(|s| &**s) + self.0.get(&ByteStr::new(key)).map(|s| s.as_str()) } /// Look up the value for a key and parse it into a value of type `T`. @@ -628,20 +628,20 @@ impl UrlParamsMap { #[async_trait] impl FromRequest for UrlParamsMap { - type Rejection = UrlParamsMapRejection; + type Rejection = MissingRouteParams; async fn from_request(req: &mut Request) -> Result { if let Some(params) = req .extensions_mut() .get_mut::>() { - if let Some(params) = params.take() { - Ok(Self(params.0.into_iter().collect())) + if let Some(params) = params { + Ok(Self(params.0.iter().cloned().collect())) } else { - Err(UrlParamsAlreadyExtracted.into()) + Ok(Self(Default::default())) } } else { - Err(MissingRouteParams.into()) + Err(MissingRouteParams) } } } @@ -689,24 +689,24 @@ macro_rules! impl_parse_url { .extensions_mut() .get_mut::>() { - if let Some(params) = params.take() { - params.0 + if let Some(params) = params { + params.0.clone() } else { - return Err(UrlParamsAlreadyExtracted.into()); + Default::default() } } else { return Err(MissingRouteParams.into()) }; if let [(_, $head), $((_, $tail),)*] = &*params { - let $head = if let Ok(x) = $head.parse::<$head>() { + let $head = if let Ok(x) = $head.as_str().parse::<$head>() { x } else { return Err(InvalidUrlParam::new::<$head>().into()); }; $( - let $tail = if let Ok(x) = $tail.parse::<$tail>() { + let $tail = if let Ok(x) = $tail.as_str().parse::<$tail>() { x } else { return Err(InvalidUrlParam::new::<$tail>().into()); diff --git a/src/extract/rejection.rs b/src/extract/rejection.rs index 8f27efee..df4e6b34 100644 --- a/src/extract/rejection.rs +++ b/src/extract/rejection.rs @@ -127,13 +127,6 @@ define_rejection! { pub struct MissingRouteParams; } -define_rejection! { - #[status = INTERNAL_SERVER_ERROR] - #[body = "Cannot have two URL capture extractors for a single handler"] - /// Rejection type used if you try and extract the URL params more than once. - pub struct UrlParamsAlreadyExtracted; -} - define_rejection! { #[status = INTERNAL_SERVER_ERROR] #[body = "Cannot have two request body extractors for a single handler"] @@ -288,17 +281,6 @@ composite_rejection! { } } -composite_rejection! { - /// Rejection used for [`UrlParamsMap`](super::UrlParamsMap). - /// - /// Contains one variant for each way the [`UrlParamsMap`](super::UrlParamsMap) extractor - /// can fail. - pub enum UrlParamsMapRejection { - UrlParamsAlreadyExtracted, - MissingRouteParams, - } -} - composite_rejection! { /// Rejection used for [`UrlParams`](super::UrlParams). /// @@ -306,7 +288,6 @@ composite_rejection! { /// can fail. pub enum UrlParamsRejection { InvalidUrlParam, - UrlParamsAlreadyExtracted, MissingRouteParams, } } diff --git a/src/lib.rs b/src/lib.rs index 70f1703c..52861a15 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -601,7 +601,7 @@ clippy::match_like_matches_macro, clippy::type_complexity )] -#![forbid(unsafe_code)] +#![deny(unsafe_code)] #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(test, allow(clippy::float_cmp))] @@ -614,6 +614,8 @@ use tower::Service; #[macro_use] pub(crate) mod macros; +mod util; + pub mod body; pub mod extract; pub mod handler; diff --git a/src/routing.rs b/src/routing.rs index 8eb10746..0d1c5e09 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -1,6 +1,6 @@ //! Routing between [`Service`]s. -use crate::{body::BoxBody, response::IntoResponse}; +use crate::{body::BoxBody, response::IntoResponse, util::ByteStr}; use async_trait::async_trait; use bytes::Bytes; use futures_util::{future, ready}; @@ -365,15 +365,20 @@ where } #[derive(Debug)] -pub(crate) struct UrlParams(pub(crate) Vec<(String, String)>); +pub(crate) struct UrlParams(pub(crate) Vec<(ByteStr, ByteStr)>); fn insert_url_params(req: &mut Request, params: Vec<(String, String)>) { + let params = params + .into_iter() + .map(|(k, v)| (ByteStr::new(k), ByteStr::new(v))); + if let Some(current) = req.extensions_mut().get_mut::>() { let mut current = current.take().unwrap(); current.0.extend(params); req.extensions_mut().insert(Some(current)); } else { - req.extensions_mut().insert(Some(UrlParams(params))); + req.extensions_mut() + .insert(Some(UrlParams(params.collect()))); } } diff --git a/src/tests.rs b/src/tests.rs index fbb3c973..c0877a24 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -271,6 +271,30 @@ async fn extracting_url_params() { assert_eq!(res.status(), StatusCode::OK); } +#[tokio::test] +async fn extracting_url_params_multiple_times() { + let app = route( + "/users/:id", + get( + |_: extract::UrlParams<(i32,)>, + _: extract::UrlParamsMap, + _: extract::UrlParams<(i32,)>, + _: extract::UrlParamsMap| async {}, + ), + ); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/users/42", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); +} + #[tokio::test] async fn boxing() { let app = route( diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 00000000..5e129b23 --- /dev/null +++ b/src/util.rs @@ -0,0 +1,21 @@ +use bytes::Bytes; + +/// A string like type backed by `Bytes` making it cheap to clone. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub(crate) struct ByteStr(Bytes); + +impl ByteStr { + pub(crate) fn new(s: S) -> Self + where + S: AsRef, + { + Self(Bytes::copy_from_slice(s.as_ref().as_bytes())) + } + + #[allow(unsafe_code)] + pub(crate) fn as_str(&self) -> &str { + // SAFETY: `ByteStr` can only be constructed from strings which are + // always valid utf-8. + unsafe { std::str::from_utf8_unchecked(&self.0) } + } +}