Routing with dynamic parts!

This commit is contained in:
David Pedersen 2021-05-30 15:44:26 +02:00
parent 7328127a3d
commit 763d4e8d21
5 changed files with 182 additions and 30 deletions

View file

@ -1,3 +1,5 @@
#![allow(warnings)]
use bytes::Bytes;
use http::{Request, StatusCode};
use hyper::Server;
@ -20,9 +22,8 @@ async fn main() {
// build our application with some routes
let app = tower_web::app()
.at("/get")
.at("/:key")
.get(get)
.at("/set")
.post(set)
// convert it into a `Service`
.into_service();
@ -49,41 +50,36 @@ struct State {
db: HashMap<String, Bytes>,
}
#[derive(Deserialize)]
struct GetSetQueryString {
key: String,
}
async fn get(
_req: Request<Body>,
query: extract::Query<GetSetQueryString>,
params: extract::UrlParams,
state: extract::Extension<SharedState>,
) -> Result<Bytes, Error> {
let state = state.into_inner();
let db = &state.lock().unwrap().db;
let key = query.into_inner().key;
let key = params.get("key")?;
if let Some(value) = db.get(&key) {
if let Some(value) = db.get(key) {
Ok(value.clone())
} else {
Err(Error::WithStatus(StatusCode::NOT_FOUND))
Err(Error::Status(StatusCode::NOT_FOUND))
}
}
async fn set(
_req: Request<Body>,
query: extract::Query<GetSetQueryString>,
params: extract::UrlParams,
value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb
state: extract::Extension<SharedState>,
) -> Result<response::Empty, Error> {
let state = state.into_inner();
let db = &mut state.lock().unwrap().db;
let key = query.into_inner().key;
let key = params.get("key")?;
let value = value.into_inner();
db.insert(key, value);
db.insert(key.to_string(), value);
Ok(response::Empty)
}

View file

@ -37,7 +37,10 @@ pub enum Error {
PayloadTooLarge,
#[error("response failed with status {0}")]
WithStatus(StatusCode),
Status(StatusCode),
#[error("unknown URL param `{0}`")]
UnknownUrlParam(String),
}
impl From<Infallible> for Error {
@ -64,14 +67,14 @@ where
| Error::QueryStringMissing
| Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST),
Error::WithStatus(status) => make_response(status),
Error::Status(status) => make_response(status),
Error::LengthRequired => make_response(StatusCode::LENGTH_REQUIRED),
Error::PayloadTooLarge => make_response(StatusCode::PAYLOAD_TOO_LARGE),
Error::MissingExtension { .. } | Error::SerializeResponseBody(_) => {
make_response(StatusCode::INTERNAL_SERVER_ERROR)
}
Error::MissingExtension { .. }
| Error::SerializeResponseBody(_)
| Error::UnknownUrlParam(_) => make_response(StatusCode::INTERNAL_SERVER_ERROR),
Error::Service(err) => match err.downcast::<Error>() {
Ok(err) => Err(*err),

View file

@ -6,8 +6,10 @@ use http_body::Body as _;
use pin_project::pin_project;
use serde::de::DeserializeOwned;
use std::{
collections::HashMap,
future::Future,
pin::Pin,
str::FromStr,
task::{Context, Poll},
};
@ -181,3 +183,31 @@ impl<const N: u64> FromRequest for BytesMaxLength<N> {
})
}
}
pub struct UrlParams(HashMap<String, String>);
impl UrlParams {
pub fn get(&self, key: &str) -> Result<&str, Error> {
if let Some(value) = self.0.get(key) {
Ok(value)
} else {
Err(Error::UnknownUrlParam(key.to_string()))
}
}
}
impl FromRequest for UrlParams {
type Future = future::Ready<Result<Self, Error>>;
fn from_request(req: &mut Request<Body>) -> Self::Future {
if let Some(params) = req
.extensions_mut()
.get_mut::<Option<crate::routing::UrlParams>>()
{
let params = params.take().expect("params already taken").0;
future::ok(Self(params.into_iter().collect()))
} else {
panic!("no url params found for matched route. This is a bug in tower-web")
}
}
}

View file

@ -6,8 +6,6 @@ Improvements to make:
Support extracting headers, perhaps via `headers::Header`?
Actual routing
Improve compile times with lots of routes, can we box and combine routers?
Tests

View file

@ -84,13 +84,15 @@ impl<R> RouteAt<R> {
}
fn add_route_service<S>(self, service: S, method: Method) -> RouteBuilder<Route<S, R>> {
assert!(
self.route_spec.starts_with(b"/"),
"route spec must start with a slash (`/`)"
);
let new_app = App {
router: Route {
service,
route_spec: RouteSpec {
method,
spec: self.route_spec.clone(),
},
route_spec: RouteSpec::new(method, self.route_spec.clone()),
fallback: self.app.router,
handler_ready: false,
fallback_ready: false,
@ -196,9 +198,47 @@ struct RouteSpec {
}
impl RouteSpec {
fn matches<B>(&self, req: &Request<B>) -> bool {
// TODO(david): support dynamic placeholders like `/users/:id`
req.method() == self.method && req.uri().path().as_bytes() == self.spec
fn new(method: Method, spec: impl Into<Bytes>) -> Self {
Self {
method,
spec: spec.into(),
}
}
}
impl RouteSpec {
fn matches<B>(&self, req: &Request<B>) -> Option<Vec<(String, String)>> {
if req.method() != self.method {
return None;
}
let path = req.uri().path().as_bytes();
let path_parts = path.split(|b| *b == b'/');
let spec_parts = self.spec.split(|b| *b == b'/');
if spec_parts.clone().count() != path_parts.clone().count() {
return None;
}
let mut params = Vec::new();
spec_parts
.zip(path_parts)
.all(|(spec, path)| {
if let Some(key) = spec.strip_prefix(b":") {
let key = std::str::from_utf8(key).unwrap().to_string();
if let Ok(value) = std::str::from_utf8(path) {
params.push((key, value.to_string()));
true
} else {
false
}
} else {
spec == path
}
})
.then(|| params)
}
}
@ -236,8 +276,8 @@ where
}
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
if self.route_spec.matches(&req) {
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
if let Some(params) = self.route_spec.matches(&req) {
assert!(
self.handler_ready,
"handler not ready. Did you forget to call `poll_ready`?"
@ -245,6 +285,8 @@ where
self.handler_ready = false;
req.extensions_mut().insert(Some(UrlParams(params)));
future::Either::Left(BoxResponseBody(self.service.call(req)))
} else {
assert!(
@ -260,6 +302,8 @@ where
}
}
pub(crate) struct UrlParams(pub(crate) Vec<(String, String)>);
#[pin_project]
pub struct BoxResponseBody<F>(#[pin] F);
@ -282,3 +326,84 @@ where
Poll::Ready(Ok(response))
}
}
#[cfg(test)]
mod tests {
#[allow(unused_imports)]
use super::*;
#[test]
fn test_routing() {
assert_match((Method::GET, "/"), (Method::GET, "/"));
refute_match((Method::GET, "/"), (Method::POST, "/"));
refute_match((Method::POST, "/"), (Method::GET, "/"));
assert_match((Method::GET, "/foo"), (Method::GET, "/foo"));
assert_match((Method::GET, "/foo/"), (Method::GET, "/foo/"));
refute_match((Method::GET, "/foo"), (Method::GET, "/foo/"));
refute_match((Method::GET, "/foo/"), (Method::GET, "/foo"));
assert_match((Method::GET, "/foo/bar"), (Method::GET, "/foo/bar"));
refute_match((Method::GET, "/foo/bar/"), (Method::GET, "/foo/bar"));
refute_match((Method::GET, "/foo/bar"), (Method::GET, "/foo/bar/"));
assert_match((Method::GET, "/:value"), (Method::GET, "/foo"));
assert_match((Method::GET, "/users/:id"), (Method::GET, "/users/1"));
assert_match(
(Method::GET, "/users/:id/action"),
(Method::GET, "/users/42/action"),
);
refute_match(
(Method::GET, "/users/:id/action"),
(Method::GET, "/users/42"),
);
refute_match(
(Method::GET, "/users/:id"),
(Method::GET, "/users/42/action"),
);
}
fn assert_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) {
let route = RouteSpec::new(route_spec.0.clone(), route_spec.1);
let req = Request::builder()
.method(req_spec.0.clone())
.uri(req_spec.1)
.body(())
.unwrap();
assert!(
route.matches(&req).is_some(),
"`{} {}` doesn't match `{} {}`",
req.method(),
req.uri().path(),
route.method,
std::str::from_utf8(&route.spec).unwrap(),
);
}
fn refute_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) {
let route = RouteSpec::new(route_spec.0.clone(), route_spec.1);
let req = Request::builder()
.method(req_spec.0.clone())
.uri(req_spec.1)
.body(())
.unwrap();
assert!(
route.matches(&req).is_none(),
"`{} {}` shouldn't match `{} {}`",
req.method(),
req.uri().path(),
route.method,
std::str::from_utf8(&route.spec).unwrap(),
);
}
fn route(method: Method, uri: &'static str) -> RouteSpec {
RouteSpec::new(method, uri)
}
fn req(method: Method, uri: &str) -> Request<()> {
Request::builder().uri(uri).method(method).body(()).unwrap()
}
}