Add extension extractor

This commit is contained in:
David Pedersen 2021-05-30 11:07:56 +02:00
parent b763eaa037
commit 2e16842431
2 changed files with 84 additions and 34 deletions

View file

@ -22,5 +22,5 @@ tower = { version = "0.4", features = ["util"] }
tokio = { version = "1.6.1", features = ["macros", "rt"] }
serde = { version = "1.0", features = ["derive"] }
tower = { version = "0.4", features = ["util", "make", "timeout"] }
tower-http = { version = "0.1", features = ["trace", "compression"] }
tower-http = { version = "0.1", features = ["trace", "compression", "add-extension"] }
hyper = { version = "0.14", features = ["full"] }

View file

@ -205,6 +205,9 @@ pub enum Error {
#[error("handler service returned an error")]
Service(#[source] BoxError),
#[error("request extension was not set")]
MissingExtension { type_name: &'static str },
}
impl From<Infallible> for Error {
@ -463,6 +466,37 @@ where
}
}
#[derive(Debug, Clone, Copy)]
pub struct Extension<T>(T);
impl<T> Extension<T> {
pub fn into_inner(self) -> T {
self.0
}
}
impl<T> FromRequest for Extension<T>
where
T: Clone + Send + Sync + 'static,
{
type Future = future::Ready<Result<Self, Error>>;
fn from_request(req: &mut Request<Body>) -> Self::Future {
let result = (|| {
let value = req
.extensions()
.get::<T>()
.ok_or_else(|| Error::MissingExtension {
type_name: std::any::type_name::<T>(),
})
.map(|x| x.clone())?;
Ok(Extension(value))
})();
future::ready(result)
}
}
#[derive(Clone, Copy)]
pub struct EmptyRouter(());
@ -691,6 +725,8 @@ where
| Error::QueryStringMissing
| Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST),
Error::MissingExtension { .. } => make_response(StatusCode::INTERNAL_SERVER_ERROR),
Error::Service(err) => match err.downcast::<Error>() {
Ok(err) => Err(*err),
Err(err) => Err(Error::Service(err)),
@ -707,11 +743,12 @@ mod tests {
use super::*;
use hyper::Server;
use std::time::Duration;
use std::{fmt, net::SocketAddr};
use std::{fmt, net::SocketAddr, sync::Arc};
use tower::{
layer::util::Identity, make::Shared, service_fn, timeout::TimeoutLayer, ServiceBuilder,
};
use tower_http::{
add_extension::AddExtensionLayer,
compression::CompressionLayer,
trace::{Trace, TraceLayer},
};
@ -737,7 +774,8 @@ mod tests {
Ok(Response::new(Body::empty()))
}
let app = app()
let app =
app()
// routes with functions
.at("/")
.get(root)
@ -750,12 +788,16 @@ mod tests {
Ok::<_, Error>(Response::new(Body::from("users#index")))
})
.post(|_: Request<Body>, payload: Json<UsersCreate>| async {
.post(
|_: Request<Body>,
payload: Json<UsersCreate>,
_state: Extension<Arc<State>>| async {
let payload = payload.into_inner();
assert_eq!(payload.username, "bob");
Ok::<_, Error>(Response::new(Body::from("users#create")))
})
},
)
// routes with a service
.at("/service")
.get_service(service_fn(root))
@ -771,8 +813,16 @@ mod tests {
)
.into_service();
// state shared by all routes, could hold db connection etc
struct State {}
let state = Arc::new(State {});
// can add more middleware
let mut app = Trace::new_for_http(app);
let mut app = ServiceBuilder::new()
.layer(AddExtensionLayer::new(state))
.layer(TraceLayer::new_for_http())
.service(app);
let res = app
.ready()