mirror of
https://github.com/tokio-rs/axum.git
synced 2025-03-08 08:56:24 +01:00
parent
2be79168d8
commit
9fdbd42fba
14 changed files with 871 additions and 66 deletions
|
@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- Use `pin-project-lite` instead of `pin-project`. ([#95](https://github.com/tokio-rs/axum/pull/95))
|
||||
- Re-export `http` crate and `hyper::Server`. ([#110](https://github.com/tokio-rs/axum/pull/110))
|
||||
- Fix `Query` and `Form` extractors giving bad request error when query string is empty. ([#117](https://github.com/tokio-rs/axum/pull/117))
|
||||
- Add `Path` extractor. ([#124](https://github.com/tokio-rs/axum/pull/124))
|
||||
|
||||
## Breaking changes
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
|
||||
use axum::{
|
||||
async_trait,
|
||||
extract::{Extension, Json, UrlParams},
|
||||
extract::{Extension, Json, Path},
|
||||
prelude::*,
|
||||
response::IntoResponse,
|
||||
AddExtensionLayer,
|
||||
|
@ -56,7 +56,7 @@ async fn main() {
|
|||
/// are automatically converted into `AppError` which implements `IntoResponse`
|
||||
/// so it can be returned from handlers directly.
|
||||
async fn users_show(
|
||||
UrlParams((user_id,)): UrlParams<(Uuid,)>,
|
||||
Path(user_id): Path<Uuid>,
|
||||
Extension(user_repo): Extension<DynUserRepo>,
|
||||
) -> Result<response::Json<User>, AppError> {
|
||||
let user = user_repo.find(user_id).await?;
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
|
||||
use axum::{
|
||||
async_trait,
|
||||
extract::{extractor_middleware, ContentLengthLimit, Extension, RequestParts, UrlParams},
|
||||
extract::{extractor_middleware, ContentLengthLimit, Extension, Path, RequestParts},
|
||||
prelude::*,
|
||||
response::IntoResponse,
|
||||
routing::BoxRoute,
|
||||
|
@ -79,7 +79,7 @@ struct State {
|
|||
}
|
||||
|
||||
async fn kv_get(
|
||||
UrlParams((key,)): UrlParams<(String,)>,
|
||||
Path(key): Path<String>,
|
||||
Extension(state): Extension<SharedState>,
|
||||
) -> Result<Bytes, StatusCode> {
|
||||
let db = &state.read().unwrap().db;
|
||||
|
@ -92,7 +92,7 @@ async fn kv_get(
|
|||
}
|
||||
|
||||
async fn kv_set(
|
||||
UrlParams((key,)): UrlParams<(String,)>,
|
||||
Path(key): Path<String>,
|
||||
ContentLengthLimit(bytes): ContentLengthLimit<Bytes, { 1024 * 5_000 }>, // ~5mb
|
||||
Extension(state): Extension<SharedState>,
|
||||
) {
|
||||
|
@ -113,10 +113,7 @@ fn admin_routes() -> BoxRoute<hyper::Body> {
|
|||
state.write().unwrap().db.clear();
|
||||
}
|
||||
|
||||
async fn remove_key(
|
||||
UrlParams((key,)): UrlParams<(String,)>,
|
||||
Extension(state): Extension<SharedState>,
|
||||
) {
|
||||
async fn remove_key(Path(key): Path<String>, Extension(state): Extension<SharedState>) {
|
||||
state.write().unwrap().db.remove(&key);
|
||||
}
|
||||
|
||||
|
|
|
@ -29,14 +29,8 @@ async fn main() {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
async fn greet(params: extract::UrlParamsMap) -> impl IntoResponse {
|
||||
let name = params
|
||||
.get("name")
|
||||
.expect("`name` will be there if route was matched")
|
||||
.to_string();
|
||||
|
||||
async fn greet(extract::Path(name): extract::Path<String>) -> impl IntoResponse {
|
||||
let template = HelloTemplate { name };
|
||||
|
||||
HtmlTemplate(template)
|
||||
}
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
//! ```
|
||||
|
||||
use axum::{
|
||||
extract::{Extension, Json, Query, UrlParams},
|
||||
extract::{Extension, Json, Path, Query},
|
||||
prelude::*,
|
||||
response::IntoResponse,
|
||||
service::ServiceExt,
|
||||
|
@ -129,7 +129,7 @@ struct UpdateTodo {
|
|||
}
|
||||
|
||||
async fn todos_update(
|
||||
UrlParams((id,)): UrlParams<(Uuid,)>,
|
||||
Path(id): Path<Uuid>,
|
||||
Json(input): Json<UpdateTodo>,
|
||||
Extension(db): Extension<Db>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
|
@ -153,10 +153,7 @@ async fn todos_update(
|
|||
Ok(response::Json(todo))
|
||||
}
|
||||
|
||||
async fn todos_delete(
|
||||
UrlParams((id,)): UrlParams<(Uuid,)>,
|
||||
Extension(db): Extension<Db>,
|
||||
) -> impl IntoResponse {
|
||||
async fn todos_delete(Path(id): Path<Uuid>, Extension(db): Extension<Db>) -> impl IntoResponse {
|
||||
if db.write().unwrap().remove(&id).is_some() {
|
||||
StatusCode::NO_CONTENT
|
||||
} else {
|
||||
|
|
|
@ -12,6 +12,7 @@ use axum::{
|
|||
};
|
||||
use http::Response;
|
||||
use http::StatusCode;
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
#[tokio::main]
|
||||
|
@ -53,7 +54,7 @@ where
|
|||
type Rejection = Response<Body>;
|
||||
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
let params = extract::UrlParamsMap::from_request(req)
|
||||
let params = extract::Path::<HashMap<String, String>>::from_request(req)
|
||||
.await
|
||||
.map_err(IntoResponse::into_response)?;
|
||||
|
||||
|
@ -61,7 +62,7 @@ where
|
|||
.get("version")
|
||||
.ok_or_else(|| (StatusCode::NOT_FOUND, "version param missing").into_response())?;
|
||||
|
||||
match version {
|
||||
match version.as_str() {
|
||||
"v1" => Ok(Version::V1),
|
||||
"v2" => Ok(Version::V2),
|
||||
"v3" => Ok(Version::V3),
|
||||
|
|
|
@ -258,6 +258,7 @@ mod content_length_limit;
|
|||
mod extension;
|
||||
mod form;
|
||||
mod json;
|
||||
mod path;
|
||||
mod query;
|
||||
mod raw_query;
|
||||
mod request_parts;
|
||||
|
@ -273,6 +274,7 @@ pub use self::{
|
|||
extractor_middleware::extractor_middleware,
|
||||
form::Form,
|
||||
json::Json,
|
||||
path::Path,
|
||||
query::Query,
|
||||
raw_query::RawQuery,
|
||||
request_parts::{Body, BodyStream},
|
||||
|
|
671
src/extract/path/de.rs
Normal file
671
src/extract/path/de.rs
Normal file
|
@ -0,0 +1,671 @@
|
|||
use crate::routing::UrlParams;
|
||||
use crate::util::ByteStr;
|
||||
use serde::{
|
||||
de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor},
|
||||
forward_to_deserialize_any, Deserializer,
|
||||
};
|
||||
use std::fmt::{self, Display};
|
||||
|
||||
/// This type represents errors that can occur when deserializing.
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
pub(crate) struct PathDeserializerError(pub(crate) String);
|
||||
|
||||
impl de::Error for PathDeserializerError {
|
||||
#[inline]
|
||||
fn custom<T: Display>(msg: T) -> Self {
|
||||
PathDeserializerError(msg.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for PathDeserializerError {
|
||||
#[inline]
|
||||
fn description(&self) -> &str {
|
||||
"path deserializer error"
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for PathDeserializerError {
|
||||
#[inline]
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
PathDeserializerError(msg) => write!(f, "{}", msg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! unsupported_type {
|
||||
($trait_fn:ident, $name:literal) => {
|
||||
fn $trait_fn<V>(self, _: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
Err(PathDeserializerError::custom(concat!(
|
||||
"unsupported type: ",
|
||||
$name
|
||||
)))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! parse_single_value {
|
||||
($trait_fn:ident, $visit_fn:ident, $tp:literal) => {
|
||||
fn $trait_fn<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
if self.url_params.0.len() != 1 {
|
||||
return Err(PathDeserializerError::custom(
|
||||
format!(
|
||||
"wrong number of parameters: {} expected 1",
|
||||
self.url_params.0.len()
|
||||
)
|
||||
.as_str(),
|
||||
));
|
||||
}
|
||||
|
||||
let value = self.url_params.0[0].1.parse().map_err(|_| {
|
||||
PathDeserializerError::custom(format!(
|
||||
"can not parse `{:?}` to a `{}`",
|
||||
self.url_params.0[0].1.as_str(),
|
||||
$tp
|
||||
))
|
||||
})?;
|
||||
visitor.$visit_fn(value)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) struct PathDeserializer<'de> {
|
||||
url_params: &'de UrlParams,
|
||||
}
|
||||
|
||||
impl<'de> PathDeserializer<'de> {
|
||||
#[inline]
|
||||
pub(crate) fn new(url_params: &'de UrlParams) -> Self {
|
||||
PathDeserializer { url_params }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserializer<'de> for PathDeserializer<'de> {
|
||||
type Error = PathDeserializerError;
|
||||
|
||||
unsupported_type!(deserialize_any, "'any'");
|
||||
unsupported_type!(deserialize_bytes, "bytes");
|
||||
unsupported_type!(deserialize_option, "Option<T>");
|
||||
unsupported_type!(deserialize_identifier, "identifier");
|
||||
unsupported_type!(deserialize_ignored_any, "ignored_any");
|
||||
|
||||
parse_single_value!(deserialize_bool, visit_bool, "bool");
|
||||
parse_single_value!(deserialize_i8, visit_i8, "i8");
|
||||
parse_single_value!(deserialize_i16, visit_i16, "i16");
|
||||
parse_single_value!(deserialize_i32, visit_i32, "i32");
|
||||
parse_single_value!(deserialize_i64, visit_i64, "i64");
|
||||
parse_single_value!(deserialize_u8, visit_u8, "u8");
|
||||
parse_single_value!(deserialize_u16, visit_u16, "u16");
|
||||
parse_single_value!(deserialize_u32, visit_u32, "u32");
|
||||
parse_single_value!(deserialize_u64, visit_u64, "u64");
|
||||
parse_single_value!(deserialize_f32, visit_f32, "f32");
|
||||
parse_single_value!(deserialize_f64, visit_f64, "f64");
|
||||
parse_single_value!(deserialize_string, visit_string, "String");
|
||||
parse_single_value!(deserialize_byte_buf, visit_string, "String");
|
||||
parse_single_value!(deserialize_char, visit_char, "char");
|
||||
|
||||
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
if self.url_params.0.len() != 1 {
|
||||
return Err(PathDeserializerError::custom(format!(
|
||||
"wrong number of parameters: {} expected 1",
|
||||
self.url_params.0.len()
|
||||
)));
|
||||
}
|
||||
visitor.visit_str(&self.url_params.0[0].1)
|
||||
}
|
||||
|
||||
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
visitor.visit_unit()
|
||||
}
|
||||
|
||||
fn deserialize_unit_struct<V>(
|
||||
self,
|
||||
_name: &'static str,
|
||||
visitor: V,
|
||||
) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
visitor.visit_unit()
|
||||
}
|
||||
|
||||
fn deserialize_newtype_struct<V>(
|
||||
self,
|
||||
_name: &'static str,
|
||||
visitor: V,
|
||||
) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
visitor.visit_newtype_struct(self)
|
||||
}
|
||||
|
||||
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
visitor.visit_seq(SeqDeserializer {
|
||||
params: &self.url_params.0,
|
||||
})
|
||||
}
|
||||
|
||||
fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
if self.url_params.0.len() < len {
|
||||
return Err(PathDeserializerError::custom(
|
||||
format!(
|
||||
"wrong number of parameters: {} expected {}",
|
||||
self.url_params.0.len(),
|
||||
len
|
||||
)
|
||||
.as_str(),
|
||||
));
|
||||
}
|
||||
visitor.visit_seq(SeqDeserializer {
|
||||
params: &self.url_params.0,
|
||||
})
|
||||
}
|
||||
|
||||
fn deserialize_tuple_struct<V>(
|
||||
self,
|
||||
_name: &'static str,
|
||||
len: usize,
|
||||
visitor: V,
|
||||
) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
if self.url_params.0.len() < len {
|
||||
return Err(PathDeserializerError::custom(
|
||||
format!(
|
||||
"wrong number of parameters: {} expected {}",
|
||||
self.url_params.0.len(),
|
||||
len
|
||||
)
|
||||
.as_str(),
|
||||
));
|
||||
}
|
||||
visitor.visit_seq(SeqDeserializer {
|
||||
params: &self.url_params.0,
|
||||
})
|
||||
}
|
||||
|
||||
fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
visitor.visit_map(MapDeserializer {
|
||||
params: &self.url_params.0,
|
||||
value: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn deserialize_struct<V>(
|
||||
self,
|
||||
_name: &'static str,
|
||||
_fields: &'static [&'static str],
|
||||
visitor: V,
|
||||
) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
self.deserialize_map(visitor)
|
||||
}
|
||||
|
||||
fn deserialize_enum<V>(
|
||||
self,
|
||||
_name: &'static str,
|
||||
_variants: &'static [&'static str],
|
||||
visitor: V,
|
||||
) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
if self.url_params.0.len() != 1 {
|
||||
return Err(PathDeserializerError::custom(format!(
|
||||
"wrong number of parameters: {} expected 1",
|
||||
self.url_params.0.len()
|
||||
)));
|
||||
}
|
||||
|
||||
visitor.visit_enum(EnumDeserializer {
|
||||
value: &self.url_params.0[0].1,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct MapDeserializer<'de> {
|
||||
params: &'de [(ByteStr, ByteStr)],
|
||||
value: Option<&'de str>,
|
||||
}
|
||||
|
||||
impl<'de> MapAccess<'de> for MapDeserializer<'de> {
|
||||
type Error = PathDeserializerError;
|
||||
|
||||
fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
|
||||
where
|
||||
K: DeserializeSeed<'de>,
|
||||
{
|
||||
match self.params.split_first() {
|
||||
Some(((key, value), tail)) => {
|
||||
self.value = Some(value);
|
||||
self.params = tail;
|
||||
seed.deserialize(KeyDeserializer { key }).map(Some)
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: DeserializeSeed<'de>,
|
||||
{
|
||||
match self.value.take() {
|
||||
Some(value) => seed.deserialize(ValueDeserializer { value }),
|
||||
None => Err(serde::de::Error::custom("value is missing")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct KeyDeserializer<'de> {
|
||||
key: &'de str,
|
||||
}
|
||||
|
||||
macro_rules! parse_key {
|
||||
($trait_fn:ident) => {
|
||||
fn $trait_fn<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
visitor.visit_str(self.key)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl<'de> Deserializer<'de> for KeyDeserializer<'de> {
|
||||
type Error = PathDeserializerError;
|
||||
|
||||
parse_key!(deserialize_identifier);
|
||||
parse_key!(deserialize_str);
|
||||
parse_key!(deserialize_string);
|
||||
|
||||
fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
Err(PathDeserializerError::custom("Unexpected"))
|
||||
}
|
||||
|
||||
forward_to_deserialize_any! {
|
||||
bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char bytes
|
||||
byte_buf option unit unit_struct seq tuple
|
||||
tuple_struct map newtype_struct struct enum ignored_any
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! parse_value {
|
||||
($trait_fn:ident, $visit_fn:ident, $ty:literal) => {
|
||||
fn $trait_fn<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
let v = self.value.parse().map_err(|_| {
|
||||
PathDeserializerError::custom(format!(
|
||||
"can not parse `{:?}` to a `{}`",
|
||||
self.value, $ty
|
||||
))
|
||||
})?;
|
||||
visitor.$visit_fn(v)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
struct ValueDeserializer<'de> {
|
||||
value: &'de str,
|
||||
}
|
||||
|
||||
impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
|
||||
type Error = PathDeserializerError;
|
||||
|
||||
unsupported_type!(deserialize_any, "any");
|
||||
unsupported_type!(deserialize_seq, "seq");
|
||||
unsupported_type!(deserialize_map, "map");
|
||||
unsupported_type!(deserialize_identifier, "identifier");
|
||||
|
||||
parse_value!(deserialize_bool, visit_bool, "bool");
|
||||
parse_value!(deserialize_i8, visit_i8, "i8");
|
||||
parse_value!(deserialize_i16, visit_i16, "i16");
|
||||
parse_value!(deserialize_i32, visit_i32, "i16");
|
||||
parse_value!(deserialize_i64, visit_i64, "i64");
|
||||
parse_value!(deserialize_u8, visit_u8, "u8");
|
||||
parse_value!(deserialize_u16, visit_u16, "u16");
|
||||
parse_value!(deserialize_u32, visit_u32, "u32");
|
||||
parse_value!(deserialize_u64, visit_u64, "u64");
|
||||
parse_value!(deserialize_f32, visit_f32, "f32");
|
||||
parse_value!(deserialize_f64, visit_f64, "f64");
|
||||
parse_value!(deserialize_string, visit_string, "String");
|
||||
parse_value!(deserialize_byte_buf, visit_string, "String");
|
||||
parse_value!(deserialize_char, visit_char, "char");
|
||||
|
||||
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
visitor.visit_borrowed_str(self.value)
|
||||
}
|
||||
|
||||
fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
visitor.visit_borrowed_bytes(self.value.as_bytes())
|
||||
}
|
||||
|
||||
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
visitor.visit_some(self)
|
||||
}
|
||||
|
||||
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
visitor.visit_unit()
|
||||
}
|
||||
|
||||
fn deserialize_unit_struct<V>(
|
||||
self,
|
||||
_name: &'static str,
|
||||
visitor: V,
|
||||
) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
visitor.visit_unit()
|
||||
}
|
||||
|
||||
fn deserialize_newtype_struct<V>(
|
||||
self,
|
||||
_name: &'static str,
|
||||
visitor: V,
|
||||
) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
visitor.visit_newtype_struct(self)
|
||||
}
|
||||
|
||||
fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
Err(PathDeserializerError::custom("unsupported type: tuple"))
|
||||
}
|
||||
|
||||
fn deserialize_tuple_struct<V>(
|
||||
self,
|
||||
_name: &'static str,
|
||||
_len: usize,
|
||||
_visitor: V,
|
||||
) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
Err(PathDeserializerError::custom(
|
||||
"unsupported type: tuple struct",
|
||||
))
|
||||
}
|
||||
|
||||
fn deserialize_struct<V>(
|
||||
self,
|
||||
_name: &'static str,
|
||||
_fields: &'static [&'static str],
|
||||
_visitor: V,
|
||||
) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
Err(PathDeserializerError::custom("unsupported type: struct"))
|
||||
}
|
||||
|
||||
fn deserialize_enum<V>(
|
||||
self,
|
||||
_name: &'static str,
|
||||
_variants: &'static [&'static str],
|
||||
visitor: V,
|
||||
) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
visitor.visit_enum(EnumDeserializer { value: self.value })
|
||||
}
|
||||
|
||||
fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
visitor.visit_unit()
|
||||
}
|
||||
}
|
||||
|
||||
struct EnumDeserializer<'de> {
|
||||
value: &'de str,
|
||||
}
|
||||
|
||||
impl<'de> EnumAccess<'de> for EnumDeserializer<'de> {
|
||||
type Error = PathDeserializerError;
|
||||
type Variant = UnitVariant;
|
||||
|
||||
fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
|
||||
where
|
||||
V: de::DeserializeSeed<'de>,
|
||||
{
|
||||
Ok((
|
||||
seed.deserialize(KeyDeserializer { key: self.value })?,
|
||||
UnitVariant,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
struct UnitVariant;
|
||||
|
||||
impl<'de> VariantAccess<'de> for UnitVariant {
|
||||
type Error = PathDeserializerError;
|
||||
|
||||
fn unit_variant(self) -> Result<(), Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn newtype_variant_seed<T>(self, _seed: T) -> Result<T::Value, Self::Error>
|
||||
where
|
||||
T: DeserializeSeed<'de>,
|
||||
{
|
||||
Err(PathDeserializerError::custom("not supported"))
|
||||
}
|
||||
|
||||
fn tuple_variant<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
Err(PathDeserializerError::custom("not supported"))
|
||||
}
|
||||
|
||||
fn struct_variant<V>(
|
||||
self,
|
||||
_fields: &'static [&'static str],
|
||||
_visitor: V,
|
||||
) -> Result<V::Value, Self::Error>
|
||||
where
|
||||
V: Visitor<'de>,
|
||||
{
|
||||
Err(PathDeserializerError::custom("not supported"))
|
||||
}
|
||||
}
|
||||
|
||||
struct SeqDeserializer<'de> {
|
||||
params: &'de [(ByteStr, ByteStr)],
|
||||
}
|
||||
|
||||
impl<'de> SeqAccess<'de> for SeqDeserializer<'de> {
|
||||
type Error = PathDeserializerError;
|
||||
|
||||
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
|
||||
where
|
||||
T: DeserializeSeed<'de>,
|
||||
{
|
||||
match self.params.split_first() {
|
||||
Some(((_, value), tail)) => {
|
||||
self.params = tail;
|
||||
Ok(Some(seed.deserialize(ValueDeserializer { value })?))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::util::ByteStr;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Deserialize, Eq, PartialEq)]
|
||||
enum MyEnum {
|
||||
A,
|
||||
B,
|
||||
#[serde(rename = "c")]
|
||||
C,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Eq, PartialEq)]
|
||||
struct Struct {
|
||||
c: String,
|
||||
b: bool,
|
||||
a: i32,
|
||||
}
|
||||
|
||||
fn create_url_params<I, K, V>(values: I) -> UrlParams
|
||||
where
|
||||
I: IntoIterator<Item = (K, V)>,
|
||||
K: AsRef<str>,
|
||||
V: AsRef<str>,
|
||||
{
|
||||
UrlParams(
|
||||
values
|
||||
.into_iter()
|
||||
.map(|(k, v)| (ByteStr::new(k), ByteStr::new(v)))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
macro_rules! check_single_value {
|
||||
($ty:ty, $value_str:literal, $value:expr) => {
|
||||
#[allow(clippy::bool_assert_comparison)]
|
||||
{
|
||||
let url_params = create_url_params([("value", $value_str)]);
|
||||
let deserializer = PathDeserializer::new(&url_params);
|
||||
assert_eq!(<$ty>::deserialize(deserializer).unwrap(), $value);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_single_value() {
|
||||
check_single_value!(bool, "true", true);
|
||||
check_single_value!(bool, "false", false);
|
||||
check_single_value!(i8, "-123", -123);
|
||||
check_single_value!(i16, "-123", -123);
|
||||
check_single_value!(i32, "-123", -123);
|
||||
check_single_value!(i64, "-123", -123);
|
||||
check_single_value!(u8, "123", 123);
|
||||
check_single_value!(u16, "123", 123);
|
||||
check_single_value!(u32, "123", 123);
|
||||
check_single_value!(u64, "123", 123);
|
||||
check_single_value!(f32, "123", 123.0);
|
||||
check_single_value!(f64, "123", 123.0);
|
||||
check_single_value!(String, "abc", "abc");
|
||||
check_single_value!(char, "a", 'a');
|
||||
|
||||
let url_params = create_url_params([("a", "B")]);
|
||||
assert_eq!(
|
||||
MyEnum::deserialize(PathDeserializer::new(&url_params)).unwrap(),
|
||||
MyEnum::B
|
||||
);
|
||||
|
||||
let url_params = create_url_params([("a", "1"), ("b", "2")]);
|
||||
assert_eq!(
|
||||
i32::deserialize(PathDeserializer::new(&url_params)).unwrap_err(),
|
||||
PathDeserializerError::custom("wrong number of parameters: 2 expected 1".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_seq() {
|
||||
let url_params = create_url_params([("a", "1"), ("b", "true"), ("c", "abc")]);
|
||||
assert_eq!(
|
||||
<(i32, bool, String)>::deserialize(PathDeserializer::new(&url_params)).unwrap(),
|
||||
(1, true, "abc".to_string())
|
||||
);
|
||||
|
||||
#[derive(Debug, Deserialize, Eq, PartialEq)]
|
||||
struct TupleStruct(i32, bool, String);
|
||||
assert_eq!(
|
||||
TupleStruct::deserialize(PathDeserializer::new(&url_params)).unwrap(),
|
||||
TupleStruct(1, true, "abc".to_string())
|
||||
);
|
||||
|
||||
let url_params = create_url_params([("a", "1"), ("b", "2"), ("c", "3")]);
|
||||
assert_eq!(
|
||||
<Vec<i32>>::deserialize(PathDeserializer::new(&url_params)).unwrap(),
|
||||
vec![1, 2, 3]
|
||||
);
|
||||
|
||||
let url_params = create_url_params([("a", "c"), ("a", "B")]);
|
||||
assert_eq!(
|
||||
<Vec<MyEnum>>::deserialize(PathDeserializer::new(&url_params)).unwrap(),
|
||||
vec![MyEnum::C, MyEnum::B]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_struct() {
|
||||
let url_params = create_url_params([("a", "1"), ("b", "true"), ("c", "abc")]);
|
||||
assert_eq!(
|
||||
Struct::deserialize(PathDeserializer::new(&url_params)).unwrap(),
|
||||
Struct {
|
||||
c: "abc".to_string(),
|
||||
b: true,
|
||||
a: 1,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_map() {
|
||||
let url_params = create_url_params([("a", "1"), ("b", "true"), ("c", "abc")]);
|
||||
assert_eq!(
|
||||
<HashMap<String, String>>::deserialize(PathDeserializer::new(&url_params)).unwrap(),
|
||||
[("a", "1"), ("b", "true"), ("c", "abc")]
|
||||
.iter()
|
||||
.map(|(key, value)| ((*key).to_string(), (*value).to_string()))
|
||||
.collect()
|
||||
);
|
||||
}
|
||||
}
|
113
src/extract/path/mod.rs
Normal file
113
src/extract/path/mod.rs
Normal file
|
@ -0,0 +1,113 @@
|
|||
mod de;
|
||||
|
||||
use super::{rejection::*, FromRequest};
|
||||
use crate::{extract::RequestParts, routing::UrlParams};
|
||||
use async_trait::async_trait;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
/// Extractor that will get captures from the URL and parse them using [`serde`](https://crates.io/crates/serde).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use axum::{extract::Path, prelude::*};
|
||||
/// use uuid::Uuid;
|
||||
///
|
||||
/// async fn users_teams_show(
|
||||
/// Path((user_id, team_id)): Path<(Uuid, Uuid)>,
|
||||
/// ) {
|
||||
/// // ...
|
||||
/// }
|
||||
///
|
||||
/// let app = route("/users/:user_id/team/:team_id", get(users_teams_show));
|
||||
/// # async {
|
||||
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
///
|
||||
/// If the path contains only one parameter, then you can omit the tuple.
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use axum::{extract::Path, prelude::*};
|
||||
/// use uuid::Uuid;
|
||||
///
|
||||
/// async fn user_info(Path(user_id): Path<Uuid>) {
|
||||
/// // ...
|
||||
/// }
|
||||
///
|
||||
/// let app = route("/users/:user_id", get(user_info));
|
||||
/// # async {
|
||||
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
///
|
||||
/// Path segments also can be deserialized into any type that implements [serde::Deserialize](https://docs.rs/serde/1.0.127/serde/trait.Deserialize.html).
|
||||
/// Path segment labels will be matched with struct field names.
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use axum::{extract::Path, prelude::*};
|
||||
/// use serde::Deserialize;
|
||||
/// use uuid::Uuid;
|
||||
///
|
||||
/// #[derive(Deserialize)]
|
||||
/// struct Params {
|
||||
/// user_id: Uuid,
|
||||
/// team_id: Uuid,
|
||||
/// }
|
||||
///
|
||||
/// async fn users_teams_show(
|
||||
/// Path(Params { user_id, team_id }): Path<Params>,
|
||||
/// ) {
|
||||
/// // ...
|
||||
/// }
|
||||
///
|
||||
/// let app = route("/users/:user_id/team/:team_id", get(users_teams_show));
|
||||
/// # async {
|
||||
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub struct Path<T>(pub T);
|
||||
|
||||
impl<T> Deref for Path<T> {
|
||||
type Target = T;
|
||||
|
||||
#[inline]
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> DerefMut for Path<T> {
|
||||
#[inline]
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T, B> FromRequest<B> for Path<T>
|
||||
where
|
||||
T: DeserializeOwned + Send,
|
||||
B: Send,
|
||||
{
|
||||
type Rejection = PathParamsRejection;
|
||||
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
const EMPTY_URL_PARAMS: &UrlParams = &UrlParams(Vec::new());
|
||||
|
||||
let url_params = if let Some(params) = req
|
||||
.extensions_mut()
|
||||
.and_then(|ext| ext.get::<Option<UrlParams>>())
|
||||
{
|
||||
params.as_ref().unwrap_or(EMPTY_URL_PARAMS)
|
||||
} else {
|
||||
return Err(MissingRouteParams.into());
|
||||
};
|
||||
|
||||
T::deserialize(de::PathDeserializer::new(url_params))
|
||||
.map_err(|err| PathParamsRejection::InvalidPathParam(InvalidPathParam::new(err.0)))
|
||||
.map(Path)
|
||||
}
|
||||
}
|
|
@ -159,6 +159,25 @@ impl IntoResponse for InvalidUrlParam {
|
|||
}
|
||||
}
|
||||
|
||||
/// Rejection type for [`Path`](super::Path) if the capture route
|
||||
/// param didn't have the expected type.
|
||||
#[derive(Debug)]
|
||||
pub struct InvalidPathParam(String);
|
||||
|
||||
impl InvalidPathParam {
|
||||
pub(super) fn new(err: impl Into<String>) -> Self {
|
||||
InvalidPathParam(err.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for InvalidPathParam {
|
||||
fn into_response(self) -> http::Response<Body> {
|
||||
let mut res = http::Response::new(Body::from(format!("Invalid URL param. {}", self.0)));
|
||||
*res.status_mut() = http::StatusCode::BAD_REQUEST;
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
/// Rejection type for extractors that deserialize query strings if the input
|
||||
/// couldn't be deserialized into the target type.
|
||||
#[derive(Debug)]
|
||||
|
@ -254,6 +273,17 @@ composite_rejection! {
|
|||
}
|
||||
}
|
||||
|
||||
composite_rejection! {
|
||||
/// Rejection used for [`Path`](super::Path).
|
||||
///
|
||||
/// Contains one variant for each way the [`Path`](super::Path) extractor
|
||||
/// can fail.
|
||||
pub enum PathParamsRejection {
|
||||
InvalidPathParam,
|
||||
MissingRouteParams,
|
||||
}
|
||||
}
|
||||
|
||||
composite_rejection! {
|
||||
/// Rejection used for [`Bytes`](bytes::Bytes).
|
||||
///
|
||||
|
|
14
src/lib.rs
14
src/lib.rs
|
@ -239,8 +239,8 @@
|
|||
//! # };
|
||||
//! ```
|
||||
//!
|
||||
//! [`extract::UrlParams`] can be used to extract params from a dynamic URL. It
|
||||
//! is compatible with any type that implements [`std::str::FromStr`], such as
|
||||
//! [`extract::Path`] can be used to extract params from a dynamic URL. It
|
||||
//! is compatible with any type that implements [`serde::Deserialize`], such as
|
||||
//! [`Uuid`]:
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
|
@ -249,9 +249,7 @@
|
|||
//!
|
||||
//! let app = route("/users/:id", post(create_user));
|
||||
//!
|
||||
//! async fn create_user(params: extract::UrlParams<(Uuid,)>) {
|
||||
//! let user_id: Uuid = (params.0).0;
|
||||
//!
|
||||
//! async fn create_user(extract::Path(user_id): extract::Path<Uuid>) {
|
||||
//! // ...
|
||||
//! }
|
||||
//! # async {
|
||||
|
@ -259,9 +257,6 @@
|
|||
//! # };
|
||||
//! ```
|
||||
//!
|
||||
//! There is also [`UrlParamsMap`](extract::UrlParamsMap) which provide a map
|
||||
//! like API for extracting URL params.
|
||||
//!
|
||||
//! You can also apply multiple extractors:
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
|
@ -284,10 +279,9 @@
|
|||
//! }
|
||||
//!
|
||||
//! async fn get_user_things(
|
||||
//! params: extract::UrlParams<(Uuid,)>,
|
||||
//! extract::Path(user_id): extract::Path<Uuid>,
|
||||
//! pagination: Option<extract::Query<Pagination>>,
|
||||
//! ) {
|
||||
//! let user_id: Uuid = (params.0).0;
|
||||
//! let pagination: Pagination = pagination.unwrap_or_default().0;
|
||||
//!
|
||||
//! // ...
|
||||
|
|
26
src/tests.rs
26
src/tests.rs
|
@ -9,6 +9,7 @@ use hyper::{Body, Server};
|
|||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
convert::Infallible,
|
||||
net::{SocketAddr, TcpListener},
|
||||
task::{Context, Poll},
|
||||
|
@ -244,20 +245,14 @@ async fn routing() {
|
|||
async fn extracting_url_params() {
|
||||
let app = route(
|
||||
"/users/:id",
|
||||
get(|params: extract::UrlParams<(i32,)>| async move {
|
||||
let (id,) = params.0;
|
||||
get(|extract::Path(id): extract::Path<i32>| async move {
|
||||
assert_eq!(id, 42);
|
||||
})
|
||||
.post(|params_map: extract::UrlParamsMap| async move {
|
||||
assert_eq!(params_map.get("id").unwrap(), "1337");
|
||||
assert_eq!(
|
||||
params_map
|
||||
.get_typed::<i32>("id")
|
||||
.expect("missing")
|
||||
.expect("failed to parse"),
|
||||
1337
|
||||
);
|
||||
}),
|
||||
.post(
|
||||
|extract::Path(params_map): extract::Path<HashMap<String, i32>>| async move {
|
||||
assert_eq!(params_map.get("id").unwrap(), &1337);
|
||||
},
|
||||
),
|
||||
);
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
@ -283,12 +278,7 @@ async fn extracting_url_params() {
|
|||
async fn extracting_url_params_multiple_times() {
|
||||
let app = route(
|
||||
"/users/:id",
|
||||
get(
|
||||
|_: extract::UrlParams<(i32,)>,
|
||||
_: extract::UrlParamsMap,
|
||||
_: extract::UrlParams<(i32,)>,
|
||||
_: extract::UrlParamsMap| async {},
|
||||
),
|
||||
get(|_: extract::Path<i32>, _: extract::Path<String>| async {}),
|
||||
);
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[tokio::test]
|
||||
async fn nesting_apps() {
|
||||
|
@ -8,23 +9,27 @@ async fn nesting_apps() {
|
|||
)
|
||||
.route(
|
||||
"/users/:id",
|
||||
get(|params: extract::UrlParamsMap| async move {
|
||||
format!(
|
||||
"{}: users#show ({})",
|
||||
params.get("version").unwrap(),
|
||||
params.get("id").unwrap()
|
||||
)
|
||||
}),
|
||||
get(
|
||||
|params: extract::Path<HashMap<String, String>>| async move {
|
||||
format!(
|
||||
"{}: users#show ({})",
|
||||
params.get("version").unwrap(),
|
||||
params.get("id").unwrap()
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.route(
|
||||
"/games/:id",
|
||||
get(|params: extract::UrlParamsMap| async move {
|
||||
format!(
|
||||
"{}: games#show ({})",
|
||||
params.get("version").unwrap(),
|
||||
params.get("id").unwrap()
|
||||
)
|
||||
}),
|
||||
get(
|
||||
|params: extract::Path<HashMap<String, String>>| async move {
|
||||
format!(
|
||||
"{}: games#show ({})",
|
||||
params.get("version").unwrap(),
|
||||
params.get("id").unwrap()
|
||||
)
|
||||
},
|
||||
),
|
||||
);
|
||||
|
||||
let app = route("/", get(|| async { "hi" })).nest("/:version/api", api_routes);
|
||||
|
|
10
src/util.rs
10
src/util.rs
|
@ -1,9 +1,19 @@
|
|||
use bytes::Bytes;
|
||||
use std::ops::Deref;
|
||||
|
||||
/// A string like type backed by `Bytes` making it cheap to clone.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub(crate) struct ByteStr(Bytes);
|
||||
|
||||
impl Deref for ByteStr {
|
||||
type Target = str;
|
||||
|
||||
#[inline]
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
impl ByteStr {
|
||||
pub(crate) fn new<S>(s: S) -> Self
|
||||
where
|
||||
|
|
Loading…
Add table
Reference in a new issue