mirror of
https://github.com/tokio-rs/axum.git
synced 2025-03-27 00:48:44 +01:00
Routing with dynamic parts!
This commit is contained in:
parent
7328127a3d
commit
763d4e8d21
5 changed files with 182 additions and 30 deletions
|
@ -1,3 +1,5 @@
|
|||
#![allow(warnings)]
|
||||
|
||||
use bytes::Bytes;
|
||||
use http::{Request, StatusCode};
|
||||
use hyper::Server;
|
||||
|
@ -20,9 +22,8 @@ async fn main() {
|
|||
|
||||
// build our application with some routes
|
||||
let app = tower_web::app()
|
||||
.at("/get")
|
||||
.at("/:key")
|
||||
.get(get)
|
||||
.at("/set")
|
||||
.post(set)
|
||||
// convert it into a `Service`
|
||||
.into_service();
|
||||
|
@ -49,41 +50,36 @@ struct State {
|
|||
db: HashMap<String, Bytes>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct GetSetQueryString {
|
||||
key: String,
|
||||
}
|
||||
|
||||
async fn get(
|
||||
_req: Request<Body>,
|
||||
query: extract::Query<GetSetQueryString>,
|
||||
params: extract::UrlParams,
|
||||
state: extract::Extension<SharedState>,
|
||||
) -> Result<Bytes, Error> {
|
||||
let state = state.into_inner();
|
||||
let db = &state.lock().unwrap().db;
|
||||
|
||||
let key = query.into_inner().key;
|
||||
let key = params.get("key")?;
|
||||
|
||||
if let Some(value) = db.get(&key) {
|
||||
if let Some(value) = db.get(key) {
|
||||
Ok(value.clone())
|
||||
} else {
|
||||
Err(Error::WithStatus(StatusCode::NOT_FOUND))
|
||||
Err(Error::Status(StatusCode::NOT_FOUND))
|
||||
}
|
||||
}
|
||||
|
||||
async fn set(
|
||||
_req: Request<Body>,
|
||||
query: extract::Query<GetSetQueryString>,
|
||||
params: extract::UrlParams,
|
||||
value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb
|
||||
state: extract::Extension<SharedState>,
|
||||
) -> Result<response::Empty, Error> {
|
||||
let state = state.into_inner();
|
||||
let db = &mut state.lock().unwrap().db;
|
||||
|
||||
let key = query.into_inner().key;
|
||||
let key = params.get("key")?;
|
||||
let value = value.into_inner();
|
||||
|
||||
db.insert(key, value);
|
||||
db.insert(key.to_string(), value);
|
||||
|
||||
Ok(response::Empty)
|
||||
}
|
||||
|
|
13
src/error.rs
13
src/error.rs
|
@ -37,7 +37,10 @@ pub enum Error {
|
|||
PayloadTooLarge,
|
||||
|
||||
#[error("response failed with status {0}")]
|
||||
WithStatus(StatusCode),
|
||||
Status(StatusCode),
|
||||
|
||||
#[error("unknown URL param `{0}`")]
|
||||
UnknownUrlParam(String),
|
||||
}
|
||||
|
||||
impl From<Infallible> for Error {
|
||||
|
@ -64,14 +67,14 @@ where
|
|||
| Error::QueryStringMissing
|
||||
| Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST),
|
||||
|
||||
Error::WithStatus(status) => make_response(status),
|
||||
Error::Status(status) => make_response(status),
|
||||
|
||||
Error::LengthRequired => make_response(StatusCode::LENGTH_REQUIRED),
|
||||
Error::PayloadTooLarge => make_response(StatusCode::PAYLOAD_TOO_LARGE),
|
||||
|
||||
Error::MissingExtension { .. } | Error::SerializeResponseBody(_) => {
|
||||
make_response(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
Error::MissingExtension { .. }
|
||||
| Error::SerializeResponseBody(_)
|
||||
| Error::UnknownUrlParam(_) => make_response(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
|
||||
Error::Service(err) => match err.downcast::<Error>() {
|
||||
Ok(err) => Err(*err),
|
||||
|
|
|
@ -6,8 +6,10 @@ use http_body::Body as _;
|
|||
use pin_project::pin_project;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
str::FromStr,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
|
@ -181,3 +183,31 @@ impl<const N: u64> FromRequest for BytesMaxLength<N> {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UrlParams(HashMap<String, String>);
|
||||
|
||||
impl UrlParams {
|
||||
pub fn get(&self, key: &str) -> Result<&str, Error> {
|
||||
if let Some(value) = self.0.get(key) {
|
||||
Ok(value)
|
||||
} else {
|
||||
Err(Error::UnknownUrlParam(key.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRequest for UrlParams {
|
||||
type Future = future::Ready<Result<Self, Error>>;
|
||||
|
||||
fn from_request(req: &mut Request<Body>) -> Self::Future {
|
||||
if let Some(params) = req
|
||||
.extensions_mut()
|
||||
.get_mut::<Option<crate::routing::UrlParams>>()
|
||||
{
|
||||
let params = params.take().expect("params already taken").0;
|
||||
future::ok(Self(params.into_iter().collect()))
|
||||
} else {
|
||||
panic!("no url params found for matched route. This is a bug in tower-web")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,8 +6,6 @@ Improvements to make:
|
|||
|
||||
Support extracting headers, perhaps via `headers::Header`?
|
||||
|
||||
Actual routing
|
||||
|
||||
Improve compile times with lots of routes, can we box and combine routers?
|
||||
|
||||
Tests
|
||||
|
|
143
src/routing.rs
143
src/routing.rs
|
@ -84,13 +84,15 @@ impl<R> RouteAt<R> {
|
|||
}
|
||||
|
||||
fn add_route_service<S>(self, service: S, method: Method) -> RouteBuilder<Route<S, R>> {
|
||||
assert!(
|
||||
self.route_spec.starts_with(b"/"),
|
||||
"route spec must start with a slash (`/`)"
|
||||
);
|
||||
|
||||
let new_app = App {
|
||||
router: Route {
|
||||
service,
|
||||
route_spec: RouteSpec {
|
||||
method,
|
||||
spec: self.route_spec.clone(),
|
||||
},
|
||||
route_spec: RouteSpec::new(method, self.route_spec.clone()),
|
||||
fallback: self.app.router,
|
||||
handler_ready: false,
|
||||
fallback_ready: false,
|
||||
|
@ -196,9 +198,47 @@ struct RouteSpec {
|
|||
}
|
||||
|
||||
impl RouteSpec {
|
||||
fn matches<B>(&self, req: &Request<B>) -> bool {
|
||||
// TODO(david): support dynamic placeholders like `/users/:id`
|
||||
req.method() == self.method && req.uri().path().as_bytes() == self.spec
|
||||
fn new(method: Method, spec: impl Into<Bytes>) -> Self {
|
||||
Self {
|
||||
method,
|
||||
spec: spec.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RouteSpec {
|
||||
fn matches<B>(&self, req: &Request<B>) -> Option<Vec<(String, String)>> {
|
||||
if req.method() != self.method {
|
||||
return None;
|
||||
}
|
||||
|
||||
let path = req.uri().path().as_bytes();
|
||||
let path_parts = path.split(|b| *b == b'/');
|
||||
|
||||
let spec_parts = self.spec.split(|b| *b == b'/');
|
||||
|
||||
if spec_parts.clone().count() != path_parts.clone().count() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut params = Vec::new();
|
||||
|
||||
spec_parts
|
||||
.zip(path_parts)
|
||||
.all(|(spec, path)| {
|
||||
if let Some(key) = spec.strip_prefix(b":") {
|
||||
let key = std::str::from_utf8(key).unwrap().to_string();
|
||||
if let Ok(value) = std::str::from_utf8(path) {
|
||||
params.push((key, value.to_string()));
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
spec == path
|
||||
}
|
||||
})
|
||||
.then(|| params)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -236,8 +276,8 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
||||
if self.route_spec.matches(&req) {
|
||||
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
|
||||
if let Some(params) = self.route_spec.matches(&req) {
|
||||
assert!(
|
||||
self.handler_ready,
|
||||
"handler not ready. Did you forget to call `poll_ready`?"
|
||||
|
@ -245,6 +285,8 @@ where
|
|||
|
||||
self.handler_ready = false;
|
||||
|
||||
req.extensions_mut().insert(Some(UrlParams(params)));
|
||||
|
||||
future::Either::Left(BoxResponseBody(self.service.call(req)))
|
||||
} else {
|
||||
assert!(
|
||||
|
@ -260,6 +302,8 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) struct UrlParams(pub(crate) Vec<(String, String)>);
|
||||
|
||||
#[pin_project]
|
||||
pub struct BoxResponseBody<F>(#[pin] F);
|
||||
|
||||
|
@ -282,3 +326,84 @@ where
|
|||
Poll::Ready(Ok(response))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[allow(unused_imports)]
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_routing() {
|
||||
assert_match((Method::GET, "/"), (Method::GET, "/"));
|
||||
refute_match((Method::GET, "/"), (Method::POST, "/"));
|
||||
refute_match((Method::POST, "/"), (Method::GET, "/"));
|
||||
|
||||
assert_match((Method::GET, "/foo"), (Method::GET, "/foo"));
|
||||
assert_match((Method::GET, "/foo/"), (Method::GET, "/foo/"));
|
||||
refute_match((Method::GET, "/foo"), (Method::GET, "/foo/"));
|
||||
refute_match((Method::GET, "/foo/"), (Method::GET, "/foo"));
|
||||
|
||||
assert_match((Method::GET, "/foo/bar"), (Method::GET, "/foo/bar"));
|
||||
refute_match((Method::GET, "/foo/bar/"), (Method::GET, "/foo/bar"));
|
||||
refute_match((Method::GET, "/foo/bar"), (Method::GET, "/foo/bar/"));
|
||||
|
||||
assert_match((Method::GET, "/:value"), (Method::GET, "/foo"));
|
||||
assert_match((Method::GET, "/users/:id"), (Method::GET, "/users/1"));
|
||||
assert_match(
|
||||
(Method::GET, "/users/:id/action"),
|
||||
(Method::GET, "/users/42/action"),
|
||||
);
|
||||
refute_match(
|
||||
(Method::GET, "/users/:id/action"),
|
||||
(Method::GET, "/users/42"),
|
||||
);
|
||||
refute_match(
|
||||
(Method::GET, "/users/:id"),
|
||||
(Method::GET, "/users/42/action"),
|
||||
);
|
||||
}
|
||||
|
||||
fn assert_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) {
|
||||
let route = RouteSpec::new(route_spec.0.clone(), route_spec.1);
|
||||
let req = Request::builder()
|
||||
.method(req_spec.0.clone())
|
||||
.uri(req_spec.1)
|
||||
.body(())
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
route.matches(&req).is_some(),
|
||||
"`{} {}` doesn't match `{} {}`",
|
||||
req.method(),
|
||||
req.uri().path(),
|
||||
route.method,
|
||||
std::str::from_utf8(&route.spec).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
fn refute_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) {
|
||||
let route = RouteSpec::new(route_spec.0.clone(), route_spec.1);
|
||||
let req = Request::builder()
|
||||
.method(req_spec.0.clone())
|
||||
.uri(req_spec.1)
|
||||
.body(())
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
route.matches(&req).is_none(),
|
||||
"`{} {}` shouldn't match `{} {}`",
|
||||
req.method(),
|
||||
req.uri().path(),
|
||||
route.method,
|
||||
std::str::from_utf8(&route.spec).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
fn route(method: Method, uri: &'static str) -> RouteSpec {
|
||||
RouteSpec::new(method, uri)
|
||||
}
|
||||
|
||||
fn req(method: Method, uri: &str) -> Request<()> {
|
||||
Request::builder().uri(uri).method(method).body(()).unwrap()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue