From a2b568c7c1b3cf9dfe7da0b162c2e43c5d225fd4 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 1 Mar 2022 00:39:22 +0100 Subject: [PATCH] Implement `tower::Layer` for `Extension` (#801) * Implement `tower::Layer` for `Extension` * changelog --- axum/CHANGELOG.md | 2 ++ axum/src/add_extension.rs | 4 ++-- axum/src/extract/connect_info.rs | 4 ++-- axum/src/extract/extension.rs | 19 ++++++++++++++++--- axum/src/extract/request_parts.rs | 9 +++------ axum/src/lib.rs | 9 +++------ axum/src/middleware/from_fn.rs | 4 ++-- examples/async-graphql/src/main.rs | 4 ++-- examples/chat/src/main.rs | 4 ++-- .../src/main.rs | 4 ++-- examples/key-value-store/src/main.rs | 5 ++--- examples/oauth/src/main.rs | 6 +++--- examples/reverse-proxy/src/main.rs | 4 ++-- examples/sessions/src/main.rs | 4 ++-- examples/sqlx-postgres/src/main.rs | 4 ++-- examples/todos/src/main.rs | 4 ++-- examples/tokio-postgres/src/main.rs | 4 ++-- 17 files changed, 51 insertions(+), 43 deletions(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index eb8df458..407d7805 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** Document sharing state between handler and middleware (#783]) - **added:** `Extension<_>` can now be used in tuples for building responses, and will set an extension on the response ([#797]) +- **added:** Implement `tower::Layer` for `Extension` ([#801]) - **breaking:** `sse::Event` now accepts types implementing `AsRef` instead of `Into` as field values. - **breaking:** `sse::Event` now panics if a setter method is called twice instead of silently @@ -74,6 +75,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#791]: https://github.com/tokio-rs/axum/pull/791 [#797]: https://github.com/tokio-rs/axum/pull/797 [#800]: https://github.com/tokio-rs/axum/pull/800 +[#801]: https://github.com/tokio-rs/axum/pull/801 # 0.4.4 (13. January, 2022) diff --git a/axum/src/add_extension.rs b/axum/src/add_extension.rs index 4bf88ee6..19ee8952 100644 --- a/axum/src/add_extension.rs +++ b/axum/src/add_extension.rs @@ -45,8 +45,8 @@ where /// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html #[derive(Clone, Copy, Debug)] pub struct AddExtension { - inner: S, - value: T, + pub(crate) inner: S, + pub(crate) value: T, } impl AddExtension { diff --git a/axum/src/extract/connect_info.rs b/axum/src/extract/connect_info.rs index 071b0690..e5622031 100644 --- a/axum/src/extract/connect_info.rs +++ b/axum/src/extract/connect_info.rs @@ -5,7 +5,7 @@ //! [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info use super::{Extension, FromRequest, RequestParts}; -use crate::{AddExtension, AddExtensionLayer}; +use crate::AddExtension; use async_trait::async_trait; use hyper::server::conn::AddrStream; use std::{ @@ -104,7 +104,7 @@ where fn call(&mut self, target: T) -> Self::Future { let connect_info = ConnectInfo(C::connect_info(target)); - let svc = AddExtensionLayer::new(connect_info).layer(self.svc.clone()); + let svc = Extension(connect_info).layer(self.svc.clone()); ResponseFuture::new(ready(Ok(svc))) } } diff --git a/axum/src/extract/extension.rs b/axum/src/extract/extension.rs index 34fc297f..0a200600 100644 --- a/axum/src/extract/extension.rs +++ b/axum/src/extract/extension.rs @@ -12,7 +12,6 @@ use std::ops::Deref; /// /// ```rust,no_run /// use axum::{ -/// AddExtensionLayer, /// extract::Extension, /// routing::get, /// Router, @@ -33,7 +32,7 @@ use std::ops::Deref; /// let app = Router::new().route("/", get(handler)) /// // Add middleware that inserts the state into all incoming request's /// // extensions. -/// .layer(AddExtensionLayer::new(state)); +/// .layer(Extension(state)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; @@ -58,7 +57,7 @@ where .get::() .ok_or_else(|| { MissingExtension::from_err(format!( - "Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::AddExtensionLayer`.", + "Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::extract::Extension`.", std::any::type_name::() )) }) @@ -95,3 +94,17 @@ where res } } + +impl tower_layer::Layer for Extension +where + T: Clone + Send + Sync + 'static, +{ + type Service = crate::AddExtension; + + fn layer(&self, inner: S) -> Self::Service { + crate::AddExtension { + inner, + value: self.0.clone(), + } + } +} diff --git a/axum/src/extract/request_parts.rs b/axum/src/extract/request_parts.rs index 06980767..02b044c5 100644 --- a/axum/src/extract/request_parts.rs +++ b/axum/src/extract/request_parts.rs @@ -212,9 +212,10 @@ where mod tests { use crate::{ body::Body, + extract::Extension, routing::{get, post}, test_helpers::*, - AddExtensionLayer, Router, + Router, }; use http::{Method, Request, StatusCode}; @@ -247,11 +248,7 @@ mod tests { parts.extensions.get::().unwrap(); } - let client = TestClient::new( - Router::new() - .route("/", get(handler)) - .layer(AddExtensionLayer::new(Ext)), - ); + let client = TestClient::new(Router::new().route("/", get(handler)).layer(Extension(Ext))); let res = client.get("/").header("x-foo", "123").send().await; assert_eq!(res.status(), StatusCode::OK); diff --git a/axum/src/lib.rs b/axum/src/lib.rs index 1247b900..357c1021 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -173,13 +173,11 @@ //! //! ## Using request extensions //! -//! The easiest way to extract state in handlers is using [`AddExtension`] -//! middleware (applied with [`AddExtensionLayer`]) and the -//! [`Extension`](crate::extract::Extension) extractor: +//! The easiest way to extract state in handlers is using [`Extension`](crate::extract::Extension) +//! as layer and extractor: //! //! ```rust,no_run //! use axum::{ -//! AddExtensionLayer, //! extract::Extension, //! routing::get, //! Router, @@ -194,7 +192,7 @@ //! //! let app = Router::new() //! .route("/", get(handler)) -//! .layer(AddExtensionLayer::new(shared_state)); +//! .layer(Extension(shared_state)); //! //! async fn handler( //! Extension(state): Extension>, @@ -217,7 +215,6 @@ //! //! ```rust,no_run //! use axum::{ -//! AddExtensionLayer, //! Json, //! extract::{Extension, Path}, //! routing::{get, post}, diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index 8e68073f..df8d31ea 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -102,11 +102,11 @@ use tower_service::Service; /// ```rust /// use axum::{ /// Router, +/// extract::Extension, /// http::{Request, StatusCode}, /// routing::get, /// response::IntoResponse, /// middleware::{self, Next}, -/// AddExtensionLayer, /// }; /// use tower::ServiceBuilder; /// @@ -129,7 +129,7 @@ use tower_service::Service; /// .route("/", get(|| async { /* ... */ })) /// .layer( /// ServiceBuilder::new() -/// .layer(AddExtensionLayer::new(state)) +/// .layer(Extension(state)) /// .layer(middleware::from_fn(my_middleware)), /// ); /// # let app: Router = app; diff --git a/examples/async-graphql/src/main.rs b/examples/async-graphql/src/main.rs index 1e7271ef..f26f25b8 100644 --- a/examples/async-graphql/src/main.rs +++ b/examples/async-graphql/src/main.rs @@ -8,7 +8,7 @@ use axum::{ extract::Extension, response::{Html, IntoResponse}, routing::get, - AddExtensionLayer, Json, Router, + Json, Router, }; use starwars::{QueryRoot, StarWars, StarWarsSchema}; @@ -28,7 +28,7 @@ async fn main() { let app = Router::new() .route("/", get(graphql_playground).post(graphql_handler)) - .layer(AddExtensionLayer::new(schema)); + .layer(Extension(schema)); println!("Playground: http://localhost:3000"); diff --git a/examples/chat/src/main.rs b/examples/chat/src/main.rs index d7859536..8adbd52f 100644 --- a/examples/chat/src/main.rs +++ b/examples/chat/src/main.rs @@ -13,7 +13,7 @@ use axum::{ }, response::{Html, IntoResponse}, routing::get, - AddExtensionLayer, Router, + Router, }; use futures::{sink::SinkExt, stream::StreamExt}; use std::{ @@ -39,7 +39,7 @@ async fn main() { let app = Router::new() .route("/", get(index)) .route("/websocket", get(websocket_handler)) - .layer(AddExtensionLayer::new(app_state)); + .layer(Extension(app_state)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); diff --git a/examples/error-handling-and-dependency-injection/src/main.rs b/examples/error-handling-and-dependency-injection/src/main.rs index 14faebce..208bb103 100644 --- a/examples/error-handling-and-dependency-injection/src/main.rs +++ b/examples/error-handling-and-dependency-injection/src/main.rs @@ -13,7 +13,7 @@ use axum::{ http::StatusCode, response::{IntoResponse, Response}, routing::{get, post}, - AddExtensionLayer, Json, Router, + Json, Router, }; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -41,7 +41,7 @@ async fn main() { .route("/users", post(users_create)) // Add our `user_repo` to all request's extensions so handlers can access // it. - .layer(AddExtensionLayer::new(user_repo)); + .layer(Extension(user_repo)); // Run our application let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); diff --git a/examples/key-value-store/src/main.rs b/examples/key-value-store/src/main.rs index 66e68499..62fc7477 100644 --- a/examples/key-value-store/src/main.rs +++ b/examples/key-value-store/src/main.rs @@ -25,8 +25,7 @@ use std::{ }; use tower::{BoxError, ServiceBuilder}; use tower_http::{ - add_extension::AddExtensionLayer, auth::RequireAuthorizationLayer, - compression::CompressionLayer, trace::TraceLayer, + auth::RequireAuthorizationLayer, compression::CompressionLayer, trace::TraceLayer, }; #[tokio::main] @@ -58,7 +57,7 @@ async fn main() { .concurrency_limit(1024) .timeout(Duration::from_secs(10)) .layer(TraceLayer::new_for_http()) - .layer(AddExtensionLayer::new(SharedState::default())) + .layer(Extension(SharedState::default())) .into_inner(), ); diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs index b469c123..6532f8be 100644 --- a/examples/oauth/src/main.rs +++ b/examples/oauth/src/main.rs @@ -18,7 +18,7 @@ use axum::{ http::{header::SET_COOKIE, HeaderMap}, response::{IntoResponse, Redirect, Response}, routing::get, - AddExtensionLayer, Router, + Router, }; use http::header; use oauth2::{ @@ -49,8 +49,8 @@ async fn main() { .route("/auth/authorized", get(login_authorized)) .route("/protected", get(protected)) .route("/logout", get(logout)) - .layer(AddExtensionLayer::new(store)) - .layer(AddExtensionLayer::new(oauth_client)); + .layer(Extension(store)) + .layer(Extension(oauth_client)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); diff --git a/examples/reverse-proxy/src/main.rs b/examples/reverse-proxy/src/main.rs index aefe1118..3284f590 100644 --- a/examples/reverse-proxy/src/main.rs +++ b/examples/reverse-proxy/src/main.rs @@ -11,7 +11,7 @@ use axum::{ extract::Extension, http::{uri::Uri, Request, Response}, routing::get, - AddExtensionLayer, Router, + Router, }; use hyper::{client::HttpConnector, Body}; use std::{convert::TryFrom, net::SocketAddr}; @@ -26,7 +26,7 @@ async fn main() { let app = Router::new() .route("/", get(handler)) - .layer(AddExtensionLayer::new(client)); + .layer(Extension(client)); let addr = SocketAddr::from(([127, 0, 0, 1], 4000)); println!("reverse proxy listening on {}", addr); diff --git a/examples/sessions/src/main.rs b/examples/sessions/src/main.rs index 7ac6a8e4..92876278 100644 --- a/examples/sessions/src/main.rs +++ b/examples/sessions/src/main.rs @@ -16,7 +16,7 @@ use axum::{ }, response::IntoResponse, routing::get, - AddExtensionLayer, Router, + Router, }; use serde::{Deserialize, Serialize}; use std::fmt::Debug; @@ -38,7 +38,7 @@ async fn main() { let app = Router::new() .route("/", get(handler)) - .layer(AddExtensionLayer::new(store)); + .layer(Extension(store)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); diff --git a/examples/sqlx-postgres/src/main.rs b/examples/sqlx-postgres/src/main.rs index febd6632..9386b984 100644 --- a/examples/sqlx-postgres/src/main.rs +++ b/examples/sqlx-postgres/src/main.rs @@ -18,7 +18,7 @@ use axum::{ extract::{Extension, FromRequest, RequestParts}, http::StatusCode, routing::get, - AddExtensionLayer, Router, + Router, }; use sqlx::postgres::{PgPool, PgPoolOptions}; @@ -49,7 +49,7 @@ async fn main() { "/", get(using_connection_pool_extractor).post(using_connection_extractor), ) - .layer(AddExtensionLayer::new(pool)); + .layer(Extension(pool)); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); diff --git a/examples/todos/src/main.rs b/examples/todos/src/main.rs index c98a6989..85f5a86c 100644 --- a/examples/todos/src/main.rs +++ b/examples/todos/src/main.rs @@ -29,7 +29,7 @@ use std::{ time::Duration, }; use tower::{BoxError, ServiceBuilder}; -use tower_http::{add_extension::AddExtensionLayer, trace::TraceLayer}; +use tower_http::trace::TraceLayer; use uuid::Uuid; #[tokio::main] @@ -61,7 +61,7 @@ async fn main() { })) .timeout(Duration::from_secs(10)) .layer(TraceLayer::new_for_http()) - .layer(AddExtensionLayer::new(db)) + .layer(Extension(db)) .into_inner(), ); diff --git a/examples/tokio-postgres/src/main.rs b/examples/tokio-postgres/src/main.rs index f37de8bc..257f9280 100644 --- a/examples/tokio-postgres/src/main.rs +++ b/examples/tokio-postgres/src/main.rs @@ -9,7 +9,7 @@ use axum::{ extract::{Extension, FromRequest, RequestParts}, http::StatusCode, routing::get, - AddExtensionLayer, Router, + Router, }; use bb8::{Pool, PooledConnection}; use bb8_postgres::PostgresConnectionManager; @@ -36,7 +36,7 @@ async fn main() { "/", get(using_connection_pool_extractor).post(using_connection_extractor), ) - .layer(AddExtensionLayer::new(pool)); + .layer(Extension(pool)); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000));