Implement tower::Layer for Extension (#801)

* Implement `tower::Layer` for `Extension`

* changelog
This commit is contained in:
David Pedersen 2022-03-01 00:39:22 +01:00 committed by GitHub
parent 0d05b5e31f
commit a2b568c7c1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 51 additions and 43 deletions

View file

@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** Document sharing state between handler and middleware (#783])
- **added:** `Extension<_>` can now be used in tuples for building responses, and will set an
extension on the response ([#797])
- **added:** Implement `tower::Layer` for `Extension` ([#801])
- **breaking:** `sse::Event` now accepts types implementing `AsRef<str>` instead of `Into<String>`
as field values.
- **breaking:** `sse::Event` now panics if a setter method is called twice instead of silently
@ -74,6 +75,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#791]: https://github.com/tokio-rs/axum/pull/791
[#797]: https://github.com/tokio-rs/axum/pull/797
[#800]: https://github.com/tokio-rs/axum/pull/800
[#801]: https://github.com/tokio-rs/axum/pull/801
# 0.4.4 (13. January, 2022)

View file

@ -45,8 +45,8 @@ where
/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
#[derive(Clone, Copy, Debug)]
pub struct AddExtension<S, T> {
inner: S,
value: T,
pub(crate) inner: S,
pub(crate) value: T,
}
impl<S, T> AddExtension<S, T> {

View file

@ -5,7 +5,7 @@
//! [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
use super::{Extension, FromRequest, RequestParts};
use crate::{AddExtension, AddExtensionLayer};
use crate::AddExtension;
use async_trait::async_trait;
use hyper::server::conn::AddrStream;
use std::{
@ -104,7 +104,7 @@ where
fn call(&mut self, target: T) -> Self::Future {
let connect_info = ConnectInfo(C::connect_info(target));
let svc = AddExtensionLayer::new(connect_info).layer(self.svc.clone());
let svc = Extension(connect_info).layer(self.svc.clone());
ResponseFuture::new(ready(Ok(svc)))
}
}

View file

@ -12,7 +12,6 @@ use std::ops::Deref;
///
/// ```rust,no_run
/// use axum::{
/// AddExtensionLayer,
/// extract::Extension,
/// routing::get,
/// Router,
@ -33,7 +32,7 @@ use std::ops::Deref;
/// let app = Router::new().route("/", get(handler))
/// // Add middleware that inserts the state into all incoming request's
/// // extensions.
/// .layer(AddExtensionLayer::new(state));
/// .layer(Extension(state));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
@ -58,7 +57,7 @@ where
.get::<T>()
.ok_or_else(|| {
MissingExtension::from_err(format!(
"Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::AddExtensionLayer`.",
"Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::extract::Extension`.",
std::any::type_name::<T>()
))
})
@ -95,3 +94,17 @@ where
res
}
}
impl<S, T> tower_layer::Layer<S> for Extension<T>
where
T: Clone + Send + Sync + 'static,
{
type Service = crate::AddExtension<S, T>;
fn layer(&self, inner: S) -> Self::Service {
crate::AddExtension {
inner,
value: self.0.clone(),
}
}
}

View file

@ -212,9 +212,10 @@ where
mod tests {
use crate::{
body::Body,
extract::Extension,
routing::{get, post},
test_helpers::*,
AddExtensionLayer, Router,
Router,
};
use http::{Method, Request, StatusCode};
@ -247,11 +248,7 @@ mod tests {
parts.extensions.get::<Ext>().unwrap();
}
let client = TestClient::new(
Router::new()
.route("/", get(handler))
.layer(AddExtensionLayer::new(Ext)),
);
let client = TestClient::new(Router::new().route("/", get(handler)).layer(Extension(Ext)));
let res = client.get("/").header("x-foo", "123").send().await;
assert_eq!(res.status(), StatusCode::OK);

View file

@ -173,13 +173,11 @@
//!
//! ## Using request extensions
//!
//! The easiest way to extract state in handlers is using [`AddExtension`]
//! middleware (applied with [`AddExtensionLayer`]) and the
//! [`Extension`](crate::extract::Extension) extractor:
//! The easiest way to extract state in handlers is using [`Extension`](crate::extract::Extension)
//! as layer and extractor:
//!
//! ```rust,no_run
//! use axum::{
//! AddExtensionLayer,
//! extract::Extension,
//! routing::get,
//! Router,
@ -194,7 +192,7 @@
//!
//! let app = Router::new()
//! .route("/", get(handler))
//! .layer(AddExtensionLayer::new(shared_state));
//! .layer(Extension(shared_state));
//!
//! async fn handler(
//! Extension(state): Extension<Arc<State>>,
@ -217,7 +215,6 @@
//!
//! ```rust,no_run
//! use axum::{
//! AddExtensionLayer,
//! Json,
//! extract::{Extension, Path},
//! routing::{get, post},

View file

@ -102,11 +102,11 @@ use tower_service::Service;
/// ```rust
/// use axum::{
/// Router,
/// extract::Extension,
/// http::{Request, StatusCode},
/// routing::get,
/// response::IntoResponse,
/// middleware::{self, Next},
/// AddExtensionLayer,
/// };
/// use tower::ServiceBuilder;
///
@ -129,7 +129,7 @@ use tower_service::Service;
/// .route("/", get(|| async { /* ... */ }))
/// .layer(
/// ServiceBuilder::new()
/// .layer(AddExtensionLayer::new(state))
/// .layer(Extension(state))
/// .layer(middleware::from_fn(my_middleware)),
/// );
/// # let app: Router = app;

View file

@ -8,7 +8,7 @@ use axum::{
extract::Extension,
response::{Html, IntoResponse},
routing::get,
AddExtensionLayer, Json, Router,
Json, Router,
};
use starwars::{QueryRoot, StarWars, StarWarsSchema};
@ -28,7 +28,7 @@ async fn main() {
let app = Router::new()
.route("/", get(graphql_playground).post(graphql_handler))
.layer(AddExtensionLayer::new(schema));
.layer(Extension(schema));
println!("Playground: http://localhost:3000");

View file

@ -13,7 +13,7 @@ use axum::{
},
response::{Html, IntoResponse},
routing::get,
AddExtensionLayer, Router,
Router,
};
use futures::{sink::SinkExt, stream::StreamExt};
use std::{
@ -39,7 +39,7 @@ async fn main() {
let app = Router::new()
.route("/", get(index))
.route("/websocket", get(websocket_handler))
.layer(AddExtensionLayer::new(app_state));
.layer(Extension(app_state));
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));

View file

@ -13,7 +13,7 @@ use axum::{
http::StatusCode,
response::{IntoResponse, Response},
routing::{get, post},
AddExtensionLayer, Json, Router,
Json, Router,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
@ -41,7 +41,7 @@ async fn main() {
.route("/users", post(users_create))
// Add our `user_repo` to all request's extensions so handlers can access
// it.
.layer(AddExtensionLayer::new(user_repo));
.layer(Extension(user_repo));
// Run our application
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));

View file

@ -25,8 +25,7 @@ use std::{
};
use tower::{BoxError, ServiceBuilder};
use tower_http::{
add_extension::AddExtensionLayer, auth::RequireAuthorizationLayer,
compression::CompressionLayer, trace::TraceLayer,
auth::RequireAuthorizationLayer, compression::CompressionLayer, trace::TraceLayer,
};
#[tokio::main]
@ -58,7 +57,7 @@ async fn main() {
.concurrency_limit(1024)
.timeout(Duration::from_secs(10))
.layer(TraceLayer::new_for_http())
.layer(AddExtensionLayer::new(SharedState::default()))
.layer(Extension(SharedState::default()))
.into_inner(),
);

View file

@ -18,7 +18,7 @@ use axum::{
http::{header::SET_COOKIE, HeaderMap},
response::{IntoResponse, Redirect, Response},
routing::get,
AddExtensionLayer, Router,
Router,
};
use http::header;
use oauth2::{
@ -49,8 +49,8 @@ async fn main() {
.route("/auth/authorized", get(login_authorized))
.route("/protected", get(protected))
.route("/logout", get(logout))
.layer(AddExtensionLayer::new(store))
.layer(AddExtensionLayer::new(oauth_client));
.layer(Extension(store))
.layer(Extension(oauth_client));
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);

View file

@ -11,7 +11,7 @@ use axum::{
extract::Extension,
http::{uri::Uri, Request, Response},
routing::get,
AddExtensionLayer, Router,
Router,
};
use hyper::{client::HttpConnector, Body};
use std::{convert::TryFrom, net::SocketAddr};
@ -26,7 +26,7 @@ async fn main() {
let app = Router::new()
.route("/", get(handler))
.layer(AddExtensionLayer::new(client));
.layer(Extension(client));
let addr = SocketAddr::from(([127, 0, 0, 1], 4000));
println!("reverse proxy listening on {}", addr);

View file

@ -16,7 +16,7 @@ use axum::{
},
response::IntoResponse,
routing::get,
AddExtensionLayer, Router,
Router,
};
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
@ -38,7 +38,7 @@ async fn main() {
let app = Router::new()
.route("/", get(handler))
.layer(AddExtensionLayer::new(store));
.layer(Extension(store));
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);

View file

@ -18,7 +18,7 @@ use axum::{
extract::{Extension, FromRequest, RequestParts},
http::StatusCode,
routing::get,
AddExtensionLayer, Router,
Router,
};
use sqlx::postgres::{PgPool, PgPoolOptions};
@ -49,7 +49,7 @@ async fn main() {
"/",
get(using_connection_pool_extractor).post(using_connection_extractor),
)
.layer(AddExtensionLayer::new(pool));
.layer(Extension(pool));
// run it with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));

View file

@ -29,7 +29,7 @@ use std::{
time::Duration,
};
use tower::{BoxError, ServiceBuilder};
use tower_http::{add_extension::AddExtensionLayer, trace::TraceLayer};
use tower_http::trace::TraceLayer;
use uuid::Uuid;
#[tokio::main]
@ -61,7 +61,7 @@ async fn main() {
}))
.timeout(Duration::from_secs(10))
.layer(TraceLayer::new_for_http())
.layer(AddExtensionLayer::new(db))
.layer(Extension(db))
.into_inner(),
);

View file

@ -9,7 +9,7 @@ use axum::{
extract::{Extension, FromRequest, RequestParts},
http::StatusCode,
routing::get,
AddExtensionLayer, Router,
Router,
};
use bb8::{Pool, PooledConnection};
use bb8_postgres::PostgresConnectionManager;
@ -36,7 +36,7 @@ async fn main() {
"/",
get(using_connection_pool_extractor).post(using_connection_extractor),
)
.layer(AddExtensionLayer::new(pool));
.layer(Extension(pool));
// run it with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));