mirror of
https://github.com/tokio-rs/axum.git
synced 2024-12-28 07:20:12 +01:00
Add extension extractor
This commit is contained in:
parent
b763eaa037
commit
2e16842431
2 changed files with 84 additions and 34 deletions
|
@ -22,5 +22,5 @@ tower = { version = "0.4", features = ["util"] }
|
|||
tokio = { version = "1.6.1", features = ["macros", "rt"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tower = { version = "0.4", features = ["util", "make", "timeout"] }
|
||||
tower-http = { version = "0.1", features = ["trace", "compression"] }
|
||||
tower-http = { version = "0.1", features = ["trace", "compression", "add-extension"] }
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
|
|
116
src/lib.rs
116
src/lib.rs
|
@ -205,6 +205,9 @@ pub enum Error {
|
|||
|
||||
#[error("handler service returned an error")]
|
||||
Service(#[source] BoxError),
|
||||
|
||||
#[error("request extension was not set")]
|
||||
MissingExtension { type_name: &'static str },
|
||||
}
|
||||
|
||||
impl From<Infallible> for Error {
|
||||
|
@ -463,6 +466,37 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Extension<T>(T);
|
||||
|
||||
impl<T> Extension<T> {
|
||||
pub fn into_inner(self) -> T {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> FromRequest for Extension<T>
|
||||
where
|
||||
T: Clone + Send + Sync + 'static,
|
||||
{
|
||||
type Future = future::Ready<Result<Self, Error>>;
|
||||
|
||||
fn from_request(req: &mut Request<Body>) -> Self::Future {
|
||||
let result = (|| {
|
||||
let value = req
|
||||
.extensions()
|
||||
.get::<T>()
|
||||
.ok_or_else(|| Error::MissingExtension {
|
||||
type_name: std::any::type_name::<T>(),
|
||||
})
|
||||
.map(|x| x.clone())?;
|
||||
Ok(Extension(value))
|
||||
})();
|
||||
|
||||
future::ready(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct EmptyRouter(());
|
||||
|
||||
|
@ -691,6 +725,8 @@ where
|
|||
| Error::QueryStringMissing
|
||||
| Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST),
|
||||
|
||||
Error::MissingExtension { .. } => make_response(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
|
||||
Error::Service(err) => match err.downcast::<Error>() {
|
||||
Ok(err) => Err(*err),
|
||||
Err(err) => Err(Error::Service(err)),
|
||||
|
@ -707,11 +743,12 @@ mod tests {
|
|||
use super::*;
|
||||
use hyper::Server;
|
||||
use std::time::Duration;
|
||||
use std::{fmt, net::SocketAddr};
|
||||
use std::{fmt, net::SocketAddr, sync::Arc};
|
||||
use tower::{
|
||||
layer::util::Identity, make::Shared, service_fn, timeout::TimeoutLayer, ServiceBuilder,
|
||||
};
|
||||
use tower_http::{
|
||||
add_extension::AddExtensionLayer,
|
||||
compression::CompressionLayer,
|
||||
trace::{Trace, TraceLayer},
|
||||
};
|
||||
|
@ -737,42 +774,55 @@ mod tests {
|
|||
Ok(Response::new(Body::empty()))
|
||||
}
|
||||
|
||||
let app = app()
|
||||
// routes with functions
|
||||
.at("/")
|
||||
.get(root)
|
||||
// routes with closures
|
||||
.at("/users")
|
||||
.get(|_: Request<Body>, pagination: Query<Pagination>| async {
|
||||
let pagination = pagination.into_inner();
|
||||
assert_eq!(pagination.page, 1);
|
||||
assert_eq!(pagination.per_page, 30);
|
||||
let app =
|
||||
app()
|
||||
// routes with functions
|
||||
.at("/")
|
||||
.get(root)
|
||||
// routes with closures
|
||||
.at("/users")
|
||||
.get(|_: Request<Body>, pagination: Query<Pagination>| async {
|
||||
let pagination = pagination.into_inner();
|
||||
assert_eq!(pagination.page, 1);
|
||||
assert_eq!(pagination.per_page, 30);
|
||||
|
||||
Ok::<_, Error>(Response::new(Body::from("users#index")))
|
||||
})
|
||||
.post(|_: Request<Body>, payload: Json<UsersCreate>| async {
|
||||
let payload = payload.into_inner();
|
||||
assert_eq!(payload.username, "bob");
|
||||
Ok::<_, Error>(Response::new(Body::from("users#index")))
|
||||
})
|
||||
.post(
|
||||
|_: Request<Body>,
|
||||
payload: Json<UsersCreate>,
|
||||
_state: Extension<Arc<State>>| async {
|
||||
let payload = payload.into_inner();
|
||||
assert_eq!(payload.username, "bob");
|
||||
|
||||
Ok::<_, Error>(Response::new(Body::from("users#create")))
|
||||
})
|
||||
// routes with a service
|
||||
.at("/service")
|
||||
.get_service(service_fn(root))
|
||||
// routes with layers applied
|
||||
.at("/large-static-file")
|
||||
.get(
|
||||
large_static_file.layer(
|
||||
ServiceBuilder::new()
|
||||
.layer(TimeoutLayer::new(Duration::from_secs(30)))
|
||||
.layer(CompressionLayer::new())
|
||||
.into_inner(),
|
||||
),
|
||||
)
|
||||
.into_service();
|
||||
Ok::<_, Error>(Response::new(Body::from("users#create")))
|
||||
},
|
||||
)
|
||||
// routes with a service
|
||||
.at("/service")
|
||||
.get_service(service_fn(root))
|
||||
// routes with layers applied
|
||||
.at("/large-static-file")
|
||||
.get(
|
||||
large_static_file.layer(
|
||||
ServiceBuilder::new()
|
||||
.layer(TimeoutLayer::new(Duration::from_secs(30)))
|
||||
.layer(CompressionLayer::new())
|
||||
.into_inner(),
|
||||
),
|
||||
)
|
||||
.into_service();
|
||||
|
||||
// state shared by all routes, could hold db connection etc
|
||||
struct State {}
|
||||
|
||||
let state = Arc::new(State {});
|
||||
|
||||
// can add more middleware
|
||||
let mut app = Trace::new_for_http(app);
|
||||
let mut app = ServiceBuilder::new()
|
||||
.layer(AddExtensionLayer::new(state))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.service(app);
|
||||
|
||||
let res = app
|
||||
.ready()
|
||||
|
|
Loading…
Reference in a new issue