mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-11 12:31:25 +01:00
Implement tower::Layer
for Extension
(#801)
* Implement `tower::Layer` for `Extension` * changelog
This commit is contained in:
parent
0d05b5e31f
commit
a2b568c7c1
17 changed files with 51 additions and 43 deletions
|
@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- **added:** Document sharing state between handler and middleware (#783])
|
||||
- **added:** `Extension<_>` can now be used in tuples for building responses, and will set an
|
||||
extension on the response ([#797])
|
||||
- **added:** Implement `tower::Layer` for `Extension` ([#801])
|
||||
- **breaking:** `sse::Event` now accepts types implementing `AsRef<str>` instead of `Into<String>`
|
||||
as field values.
|
||||
- **breaking:** `sse::Event` now panics if a setter method is called twice instead of silently
|
||||
|
@ -74,6 +75,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
[#791]: https://github.com/tokio-rs/axum/pull/791
|
||||
[#797]: https://github.com/tokio-rs/axum/pull/797
|
||||
[#800]: https://github.com/tokio-rs/axum/pull/800
|
||||
[#801]: https://github.com/tokio-rs/axum/pull/801
|
||||
|
||||
# 0.4.4 (13. January, 2022)
|
||||
|
||||
|
|
|
@ -45,8 +45,8 @@ where
|
|||
/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct AddExtension<S, T> {
|
||||
inner: S,
|
||||
value: T,
|
||||
pub(crate) inner: S,
|
||||
pub(crate) value: T,
|
||||
}
|
||||
|
||||
impl<S, T> AddExtension<S, T> {
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
//! [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
|
||||
|
||||
use super::{Extension, FromRequest, RequestParts};
|
||||
use crate::{AddExtension, AddExtensionLayer};
|
||||
use crate::AddExtension;
|
||||
use async_trait::async_trait;
|
||||
use hyper::server::conn::AddrStream;
|
||||
use std::{
|
||||
|
@ -104,7 +104,7 @@ where
|
|||
|
||||
fn call(&mut self, target: T) -> Self::Future {
|
||||
let connect_info = ConnectInfo(C::connect_info(target));
|
||||
let svc = AddExtensionLayer::new(connect_info).layer(self.svc.clone());
|
||||
let svc = Extension(connect_info).layer(self.svc.clone());
|
||||
ResponseFuture::new(ready(Ok(svc)))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,7 +12,6 @@ use std::ops::Deref;
|
|||
///
|
||||
/// ```rust,no_run
|
||||
/// use axum::{
|
||||
/// AddExtensionLayer,
|
||||
/// extract::Extension,
|
||||
/// routing::get,
|
||||
/// Router,
|
||||
|
@ -33,7 +32,7 @@ use std::ops::Deref;
|
|||
/// let app = Router::new().route("/", get(handler))
|
||||
/// // Add middleware that inserts the state into all incoming request's
|
||||
/// // extensions.
|
||||
/// .layer(AddExtensionLayer::new(state));
|
||||
/// .layer(Extension(state));
|
||||
/// # async {
|
||||
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
|
@ -58,7 +57,7 @@ where
|
|||
.get::<T>()
|
||||
.ok_or_else(|| {
|
||||
MissingExtension::from_err(format!(
|
||||
"Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::AddExtensionLayer`.",
|
||||
"Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::extract::Extension`.",
|
||||
std::any::type_name::<T>()
|
||||
))
|
||||
})
|
||||
|
@ -95,3 +94,17 @@ where
|
|||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T> tower_layer::Layer<S> for Extension<T>
|
||||
where
|
||||
T: Clone + Send + Sync + 'static,
|
||||
{
|
||||
type Service = crate::AddExtension<S, T>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
crate::AddExtension {
|
||||
inner,
|
||||
value: self.0.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -212,9 +212,10 @@ where
|
|||
mod tests {
|
||||
use crate::{
|
||||
body::Body,
|
||||
extract::Extension,
|
||||
routing::{get, post},
|
||||
test_helpers::*,
|
||||
AddExtensionLayer, Router,
|
||||
Router,
|
||||
};
|
||||
use http::{Method, Request, StatusCode};
|
||||
|
||||
|
@ -247,11 +248,7 @@ mod tests {
|
|||
parts.extensions.get::<Ext>().unwrap();
|
||||
}
|
||||
|
||||
let client = TestClient::new(
|
||||
Router::new()
|
||||
.route("/", get(handler))
|
||||
.layer(AddExtensionLayer::new(Ext)),
|
||||
);
|
||||
let client = TestClient::new(Router::new().route("/", get(handler)).layer(Extension(Ext)));
|
||||
|
||||
let res = client.get("/").header("x-foo", "123").send().await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
|
|
@ -173,13 +173,11 @@
|
|||
//!
|
||||
//! ## Using request extensions
|
||||
//!
|
||||
//! The easiest way to extract state in handlers is using [`AddExtension`]
|
||||
//! middleware (applied with [`AddExtensionLayer`]) and the
|
||||
//! [`Extension`](crate::extract::Extension) extractor:
|
||||
//! The easiest way to extract state in handlers is using [`Extension`](crate::extract::Extension)
|
||||
//! as layer and extractor:
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use axum::{
|
||||
//! AddExtensionLayer,
|
||||
//! extract::Extension,
|
||||
//! routing::get,
|
||||
//! Router,
|
||||
|
@ -194,7 +192,7 @@
|
|||
//!
|
||||
//! let app = Router::new()
|
||||
//! .route("/", get(handler))
|
||||
//! .layer(AddExtensionLayer::new(shared_state));
|
||||
//! .layer(Extension(shared_state));
|
||||
//!
|
||||
//! async fn handler(
|
||||
//! Extension(state): Extension<Arc<State>>,
|
||||
|
@ -217,7 +215,6 @@
|
|||
//!
|
||||
//! ```rust,no_run
|
||||
//! use axum::{
|
||||
//! AddExtensionLayer,
|
||||
//! Json,
|
||||
//! extract::{Extension, Path},
|
||||
//! routing::{get, post},
|
||||
|
|
|
@ -102,11 +102,11 @@ use tower_service::Service;
|
|||
/// ```rust
|
||||
/// use axum::{
|
||||
/// Router,
|
||||
/// extract::Extension,
|
||||
/// http::{Request, StatusCode},
|
||||
/// routing::get,
|
||||
/// response::IntoResponse,
|
||||
/// middleware::{self, Next},
|
||||
/// AddExtensionLayer,
|
||||
/// };
|
||||
/// use tower::ServiceBuilder;
|
||||
///
|
||||
|
@ -129,7 +129,7 @@ use tower_service::Service;
|
|||
/// .route("/", get(|| async { /* ... */ }))
|
||||
/// .layer(
|
||||
/// ServiceBuilder::new()
|
||||
/// .layer(AddExtensionLayer::new(state))
|
||||
/// .layer(Extension(state))
|
||||
/// .layer(middleware::from_fn(my_middleware)),
|
||||
/// );
|
||||
/// # let app: Router = app;
|
||||
|
|
|
@ -8,7 +8,7 @@ use axum::{
|
|||
extract::Extension,
|
||||
response::{Html, IntoResponse},
|
||||
routing::get,
|
||||
AddExtensionLayer, Json, Router,
|
||||
Json, Router,
|
||||
};
|
||||
use starwars::{QueryRoot, StarWars, StarWarsSchema};
|
||||
|
||||
|
@ -28,7 +28,7 @@ async fn main() {
|
|||
|
||||
let app = Router::new()
|
||||
.route("/", get(graphql_playground).post(graphql_handler))
|
||||
.layer(AddExtensionLayer::new(schema));
|
||||
.layer(Extension(schema));
|
||||
|
||||
println!("Playground: http://localhost:3000");
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ use axum::{
|
|||
},
|
||||
response::{Html, IntoResponse},
|
||||
routing::get,
|
||||
AddExtensionLayer, Router,
|
||||
Router,
|
||||
};
|
||||
use futures::{sink::SinkExt, stream::StreamExt};
|
||||
use std::{
|
||||
|
@ -39,7 +39,7 @@ async fn main() {
|
|||
let app = Router::new()
|
||||
.route("/", get(index))
|
||||
.route("/websocket", get(websocket_handler))
|
||||
.layer(AddExtensionLayer::new(app_state));
|
||||
.layer(Extension(app_state));
|
||||
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ use axum::{
|
|||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::{get, post},
|
||||
AddExtensionLayer, Json, Router,
|
||||
Json, Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
@ -41,7 +41,7 @@ async fn main() {
|
|||
.route("/users", post(users_create))
|
||||
// Add our `user_repo` to all request's extensions so handlers can access
|
||||
// it.
|
||||
.layer(AddExtensionLayer::new(user_repo));
|
||||
.layer(Extension(user_repo));
|
||||
|
||||
// Run our application
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
|
|
|
@ -25,8 +25,7 @@ use std::{
|
|||
};
|
||||
use tower::{BoxError, ServiceBuilder};
|
||||
use tower_http::{
|
||||
add_extension::AddExtensionLayer, auth::RequireAuthorizationLayer,
|
||||
compression::CompressionLayer, trace::TraceLayer,
|
||||
auth::RequireAuthorizationLayer, compression::CompressionLayer, trace::TraceLayer,
|
||||
};
|
||||
|
||||
#[tokio::main]
|
||||
|
@ -58,7 +57,7 @@ async fn main() {
|
|||
.concurrency_limit(1024)
|
||||
.timeout(Duration::from_secs(10))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(AddExtensionLayer::new(SharedState::default()))
|
||||
.layer(Extension(SharedState::default()))
|
||||
.into_inner(),
|
||||
);
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ use axum::{
|
|||
http::{header::SET_COOKIE, HeaderMap},
|
||||
response::{IntoResponse, Redirect, Response},
|
||||
routing::get,
|
||||
AddExtensionLayer, Router,
|
||||
Router,
|
||||
};
|
||||
use http::header;
|
||||
use oauth2::{
|
||||
|
@ -49,8 +49,8 @@ async fn main() {
|
|||
.route("/auth/authorized", get(login_authorized))
|
||||
.route("/protected", get(protected))
|
||||
.route("/logout", get(logout))
|
||||
.layer(AddExtensionLayer::new(store))
|
||||
.layer(AddExtensionLayer::new(oauth_client));
|
||||
.layer(Extension(store))
|
||||
.layer(Extension(oauth_client));
|
||||
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
tracing::debug!("listening on {}", addr);
|
||||
|
|
|
@ -11,7 +11,7 @@ use axum::{
|
|||
extract::Extension,
|
||||
http::{uri::Uri, Request, Response},
|
||||
routing::get,
|
||||
AddExtensionLayer, Router,
|
||||
Router,
|
||||
};
|
||||
use hyper::{client::HttpConnector, Body};
|
||||
use std::{convert::TryFrom, net::SocketAddr};
|
||||
|
@ -26,7 +26,7 @@ async fn main() {
|
|||
|
||||
let app = Router::new()
|
||||
.route("/", get(handler))
|
||||
.layer(AddExtensionLayer::new(client));
|
||||
.layer(Extension(client));
|
||||
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 4000));
|
||||
println!("reverse proxy listening on {}", addr);
|
||||
|
|
|
@ -16,7 +16,7 @@ use axum::{
|
|||
},
|
||||
response::IntoResponse,
|
||||
routing::get,
|
||||
AddExtensionLayer, Router,
|
||||
Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Debug;
|
||||
|
@ -38,7 +38,7 @@ async fn main() {
|
|||
|
||||
let app = Router::new()
|
||||
.route("/", get(handler))
|
||||
.layer(AddExtensionLayer::new(store));
|
||||
.layer(Extension(store));
|
||||
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
tracing::debug!("listening on {}", addr);
|
||||
|
|
|
@ -18,7 +18,7 @@ use axum::{
|
|||
extract::{Extension, FromRequest, RequestParts},
|
||||
http::StatusCode,
|
||||
routing::get,
|
||||
AddExtensionLayer, Router,
|
||||
Router,
|
||||
};
|
||||
use sqlx::postgres::{PgPool, PgPoolOptions};
|
||||
|
||||
|
@ -49,7 +49,7 @@ async fn main() {
|
|||
"/",
|
||||
get(using_connection_pool_extractor).post(using_connection_extractor),
|
||||
)
|
||||
.layer(AddExtensionLayer::new(pool));
|
||||
.layer(Extension(pool));
|
||||
|
||||
// run it with hyper
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
|
|
|
@ -29,7 +29,7 @@ use std::{
|
|||
time::Duration,
|
||||
};
|
||||
use tower::{BoxError, ServiceBuilder};
|
||||
use tower_http::{add_extension::AddExtensionLayer, trace::TraceLayer};
|
||||
use tower_http::trace::TraceLayer;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::main]
|
||||
|
@ -61,7 +61,7 @@ async fn main() {
|
|||
}))
|
||||
.timeout(Duration::from_secs(10))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(AddExtensionLayer::new(db))
|
||||
.layer(Extension(db))
|
||||
.into_inner(),
|
||||
);
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ use axum::{
|
|||
extract::{Extension, FromRequest, RequestParts},
|
||||
http::StatusCode,
|
||||
routing::get,
|
||||
AddExtensionLayer, Router,
|
||||
Router,
|
||||
};
|
||||
use bb8::{Pool, PooledConnection};
|
||||
use bb8_postgres::PostgresConnectionManager;
|
||||
|
@ -36,7 +36,7 @@ async fn main() {
|
|||
"/",
|
||||
get(using_connection_pool_extractor).post(using_connection_extractor),
|
||||
)
|
||||
.layer(AddExtensionLayer::new(pool));
|
||||
.layer(Extension(pool));
|
||||
|
||||
// run it with hyper
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
|
|
Loading…
Reference in a new issue