mirror of
https://github.com/tokio-rs/axum.git
synced 2024-11-22 07:08:16 +01:00
Support extracting URL params multiple times (#15)
Useful when building higher order extractors.
This commit is contained in:
parent
2b360a7873
commit
c41c9e0f78
6 changed files with 69 additions and 36 deletions
|
@ -165,7 +165,7 @@
|
|||
//! # };
|
||||
//! ```
|
||||
|
||||
use crate::{body::Body, response::IntoResponse};
|
||||
use crate::{body::Body, response::IntoResponse, util::ByteStr};
|
||||
use async_trait::async_trait;
|
||||
use bytes::{Buf, Bytes};
|
||||
use http::{header, HeaderMap, Method, Request, Uri, Version};
|
||||
|
@ -609,12 +609,12 @@ where
|
|||
/// Note that you can only have one URL params extractor per handler. If you
|
||||
/// have multiple it'll response with `500 Internal Server Error`.
|
||||
#[derive(Debug)]
|
||||
pub struct UrlParamsMap(HashMap<String, String>);
|
||||
pub struct UrlParamsMap(HashMap<ByteStr, ByteStr>);
|
||||
|
||||
impl UrlParamsMap {
|
||||
/// Look up the value for a key.
|
||||
pub fn get(&self, key: &str) -> Option<&str> {
|
||||
self.0.get(key).map(|s| &**s)
|
||||
self.0.get(&ByteStr::new(key)).map(|s| s.as_str())
|
||||
}
|
||||
|
||||
/// Look up the value for a key and parse it into a value of type `T`.
|
||||
|
@ -628,20 +628,20 @@ impl UrlParamsMap {
|
|||
|
||||
#[async_trait]
|
||||
impl FromRequest for UrlParamsMap {
|
||||
type Rejection = UrlParamsMapRejection;
|
||||
type Rejection = MissingRouteParams;
|
||||
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
if let Some(params) = req
|
||||
.extensions_mut()
|
||||
.get_mut::<Option<crate::routing::UrlParams>>()
|
||||
{
|
||||
if let Some(params) = params.take() {
|
||||
Ok(Self(params.0.into_iter().collect()))
|
||||
if let Some(params) = params {
|
||||
Ok(Self(params.0.iter().cloned().collect()))
|
||||
} else {
|
||||
Err(UrlParamsAlreadyExtracted.into())
|
||||
Ok(Self(Default::default()))
|
||||
}
|
||||
} else {
|
||||
Err(MissingRouteParams.into())
|
||||
Err(MissingRouteParams)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -689,24 +689,24 @@ macro_rules! impl_parse_url {
|
|||
.extensions_mut()
|
||||
.get_mut::<Option<crate::routing::UrlParams>>()
|
||||
{
|
||||
if let Some(params) = params.take() {
|
||||
params.0
|
||||
if let Some(params) = params {
|
||||
params.0.clone()
|
||||
} else {
|
||||
return Err(UrlParamsAlreadyExtracted.into());
|
||||
Default::default()
|
||||
}
|
||||
} else {
|
||||
return Err(MissingRouteParams.into())
|
||||
};
|
||||
|
||||
if let [(_, $head), $((_, $tail),)*] = &*params {
|
||||
let $head = if let Ok(x) = $head.parse::<$head>() {
|
||||
let $head = if let Ok(x) = $head.as_str().parse::<$head>() {
|
||||
x
|
||||
} else {
|
||||
return Err(InvalidUrlParam::new::<$head>().into());
|
||||
};
|
||||
|
||||
$(
|
||||
let $tail = if let Ok(x) = $tail.parse::<$tail>() {
|
||||
let $tail = if let Ok(x) = $tail.as_str().parse::<$tail>() {
|
||||
x
|
||||
} else {
|
||||
return Err(InvalidUrlParam::new::<$tail>().into());
|
||||
|
|
|
@ -127,13 +127,6 @@ define_rejection! {
|
|||
pub struct MissingRouteParams;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = INTERNAL_SERVER_ERROR]
|
||||
#[body = "Cannot have two URL capture extractors for a single handler"]
|
||||
/// Rejection type used if you try and extract the URL params more than once.
|
||||
pub struct UrlParamsAlreadyExtracted;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = INTERNAL_SERVER_ERROR]
|
||||
#[body = "Cannot have two request body extractors for a single handler"]
|
||||
|
@ -288,17 +281,6 @@ composite_rejection! {
|
|||
}
|
||||
}
|
||||
|
||||
composite_rejection! {
|
||||
/// Rejection used for [`UrlParamsMap`](super::UrlParamsMap).
|
||||
///
|
||||
/// Contains one variant for each way the [`UrlParamsMap`](super::UrlParamsMap) extractor
|
||||
/// can fail.
|
||||
pub enum UrlParamsMapRejection {
|
||||
UrlParamsAlreadyExtracted,
|
||||
MissingRouteParams,
|
||||
}
|
||||
}
|
||||
|
||||
composite_rejection! {
|
||||
/// Rejection used for [`UrlParams`](super::UrlParams).
|
||||
///
|
||||
|
@ -306,7 +288,6 @@ composite_rejection! {
|
|||
/// can fail.
|
||||
pub enum UrlParamsRejection {
|
||||
InvalidUrlParam,
|
||||
UrlParamsAlreadyExtracted,
|
||||
MissingRouteParams,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -601,7 +601,7 @@
|
|||
clippy::match_like_matches_macro,
|
||||
clippy::type_complexity
|
||||
)]
|
||||
#![forbid(unsafe_code)]
|
||||
#![deny(unsafe_code)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
#![cfg_attr(test, allow(clippy::float_cmp))]
|
||||
|
||||
|
@ -614,6 +614,8 @@ use tower::Service;
|
|||
#[macro_use]
|
||||
pub(crate) mod macros;
|
||||
|
||||
mod util;
|
||||
|
||||
pub mod body;
|
||||
pub mod extract;
|
||||
pub mod handler;
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
//! Routing between [`Service`]s.
|
||||
|
||||
use crate::{body::BoxBody, response::IntoResponse};
|
||||
use crate::{body::BoxBody, response::IntoResponse, util::ByteStr};
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use futures_util::{future, ready};
|
||||
|
@ -365,15 +365,20 @@ where
|
|||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct UrlParams(pub(crate) Vec<(String, String)>);
|
||||
pub(crate) struct UrlParams(pub(crate) Vec<(ByteStr, ByteStr)>);
|
||||
|
||||
fn insert_url_params<B>(req: &mut Request<B>, params: Vec<(String, String)>) {
|
||||
let params = params
|
||||
.into_iter()
|
||||
.map(|(k, v)| (ByteStr::new(k), ByteStr::new(v)));
|
||||
|
||||
if let Some(current) = req.extensions_mut().get_mut::<Option<UrlParams>>() {
|
||||
let mut current = current.take().unwrap();
|
||||
current.0.extend(params);
|
||||
req.extensions_mut().insert(Some(current));
|
||||
} else {
|
||||
req.extensions_mut().insert(Some(UrlParams(params)));
|
||||
req.extensions_mut()
|
||||
.insert(Some(UrlParams(params.collect())));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
24
src/tests.rs
24
src/tests.rs
|
@ -271,6 +271,30 @@ async fn extracting_url_params() {
|
|||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn extracting_url_params_multiple_times() {
|
||||
let app = route(
|
||||
"/users/:id",
|
||||
get(
|
||||
|_: extract::UrlParams<(i32,)>,
|
||||
_: extract::UrlParamsMap,
|
||||
_: extract::UrlParams<(i32,)>,
|
||||
_: extract::UrlParamsMap| async {},
|
||||
),
|
||||
);
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/users/42", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn boxing() {
|
||||
let app = route(
|
||||
|
|
21
src/util.rs
Normal file
21
src/util.rs
Normal file
|
@ -0,0 +1,21 @@
|
|||
use bytes::Bytes;
|
||||
|
||||
/// A string like type backed by `Bytes` making it cheap to clone.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub(crate) struct ByteStr(Bytes);
|
||||
|
||||
impl ByteStr {
|
||||
pub(crate) fn new<S>(s: S) -> Self
|
||||
where
|
||||
S: AsRef<str>,
|
||||
{
|
||||
Self(Bytes::copy_from_slice(s.as_ref().as_bytes()))
|
||||
}
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
pub(crate) fn as_str(&self) -> &str {
|
||||
// SAFETY: `ByteStr` can only be constructed from strings which are
|
||||
// always valid utf-8.
|
||||
unsafe { std::str::from_utf8_unchecked(&self.0) }
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue