diff --git a/examples/key_value_store.rs b/examples/key_value_store.rs index 324b3de4..3bcfc3e3 100644 --- a/examples/key_value_store.rs +++ b/examples/key_value_store.rs @@ -52,15 +52,15 @@ struct State { async fn get( _req: Request<Body>, - params: extract::UrlParamsMap, + params: extract::UrlParams<(String,)>, state: extract::Extension<SharedState>, ) -> Result<Bytes, Error> { let state = state.into_inner(); let db = &state.lock().unwrap().db; - let key = params.get("key")?; + let (key,) = params.into_inner(); - if let Some(value) = db.get(key) { + if let Some(value) = db.get(&key) { Ok(value.clone()) } else { Err(Error::Status(StatusCode::NOT_FOUND)) @@ -69,14 +69,14 @@ async fn get( async fn set( _req: Request<Body>, - params: extract::UrlParamsMap, + params: extract::UrlParams<(String,)>, value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb state: extract::Extension<SharedState>, ) -> Result<response::Empty, Error> { let state = state.into_inner(); let db = &mut state.lock().unwrap().db; - let key = params.get("key")?; + let (key,) = params.into_inner(); let value = value.into_inner(); db.insert(key.to_string(), value); diff --git a/src/extract.rs b/src/extract.rs index d1c6b41c..b2270e84 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -197,7 +197,7 @@ impl UrlParamsMap { pub fn get_typed<T>(&self, key: &str) -> Result<T, Error> where - T: std::str::FromStr, + T: FromStr, { self.get(key)?.parse().map_err(|_| Error::InvalidUrlParam { type_name: std::any::type_name::<T>(), @@ -220,3 +220,65 @@ impl FromRequest for UrlParamsMap { } } } + +pub struct UrlParams<T>(T); + +impl<T> UrlParams<T> { + pub fn into_inner(self) -> T { + self.0 + } +} + +macro_rules! impl_parse_url { + () => {}; + + ( $head:ident, $($tail:ident),* $(,)? ) => { + impl<$head, $($tail,)*> FromRequest for UrlParams<($head, $($tail,)*)> + where + $head: FromStr + Send, + $( $tail: FromStr + Send, )* + { + type Future = future::Ready<Result<Self, Error>>; + + #[allow(non_snake_case)] + fn from_request(req: &mut Request<Body>) -> Self::Future { + let params = if let Some(params) = req + .extensions_mut() + .get_mut::<Option<crate::routing::UrlParams>>() + { + params.take().expect("params already taken").0 + } else { + panic!("no url params found for matched route. This is a bug in tower-web") + }; + + if let [(_, $head), $((_, $tail),)*] = &*params { + let $head = if let Ok(x) = $head.parse::<$head>() { + x + } else { + return future::err(Error::InvalidUrlParam { + type_name: std::any::type_name::<$head>(), + }); + }; + + $( + let $tail = if let Ok(x) = $tail.parse::<$tail>() { + x + } else { + return future::err(Error::InvalidUrlParam { + type_name: std::any::type_name::<$tail>(), + }); + }; + )* + + future::ok(UrlParams(($head, $($tail,)*))) + } else { + panic!("wrong number of url params found for matched route. This is a bug in tower-web") + } + } + } + + impl_parse_url!($($tail,)*); + }; +} + +impl_parse_url!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); diff --git a/src/handler.rs b/src/handler.rs index 7c651bff..7c1787e7 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -51,27 +51,7 @@ where } macro_rules! impl_handler { - ( $head:ident $(,)? ) => { - #[async_trait] - #[allow(non_snake_case)] - impl<F, Fut, B, Res, $head> Handler<B, ($head,)> for F - where - F: Fn(Request<Body>, $head) -> Fut + Send + Sync, - Fut: Future<Output = Result<Res, Error>> + Send, - Res: IntoResponse<B>, - $head: FromRequest + Send, - { - type Response = Res; - - type Sealed = sealed::Hidden; - - async fn call(self, mut req: Request<Body>) -> Result<Self::Response, Error> { - let $head = $head::from_request(&mut req).await?; - let res = self(req, $head).await?; - Ok(res) - } - } - }; + () => {}; ( $head:ident, $($tail:ident),* $(,)? ) => { #[async_trait]