Update to tower-http 0.4 (#1783)

This commit is contained in:
David Pedersen 2023-02-24 21:51:30 +01:00 committed by GitHub
parent 6a4825bb22
commit f726f16b6d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 131 additions and 193 deletions

View file

@ -29,7 +29,7 @@ axum = { path = "../axum", version = "0.6.0", features = ["headers"] }
futures-util = "0.3"
hyper = "0.14"
tokio = { version = "1.0", features = ["macros"] }
tower-http = { version = "0.3.4", features = ["limit"] }
tower-http = { version = "0.4", features = ["limit"] }
[package.metadata.cargo-public-api-crates]
allowed = [

View file

@ -7,8 +7,10 @@ and this project adheres to [Semantic Versioning].
# Unreleased
- **breaking:** `SpaRouter::handle_error` has been removed ([#1783])
- **breaking:** Change casing of `ProtoBuf` to `Protobuf` ([#1595])
[#1783]: https://github.com/tokio-rs/axum/pull/1783
[#1595]: https://github.com/tokio-rs/axum/pull/1595
# 0.5.0 (12. February, 2022)

View file

@ -43,7 +43,7 @@ mime = "0.3"
pin-project-lite = "0.2"
tokio = "1.19"
tower = { version = "0.4", default_features = false, features = ["util"] }
tower-http = { version = "0.3", features = ["map-response-body"] }
tower-http = { version = "0.4", features = ["map-response-body"] }
tower-layer = "0.3"
tower-service = "0.3"
@ -68,7 +68,7 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.71"
tokio = { version = "1.14", features = ["full"] }
tower = { version = "0.4", features = ["util"] }
tower-http = { version = "0.3", features = ["map-response-body", "timeout"] }
tower-http = { version = "0.4", features = ["map-response-body", "timeout"] }
[package.metadata.docs.rs]
all-features = true

View file

@ -1,23 +1,15 @@
use axum::{
body::{Body, HttpBody},
error_handling::HandleError,
response::IntoResponse,
routing::{get_service, Route},
Router,
};
use http::{Request, StatusCode};
use std::{
any::type_name,
convert::Infallible,
fmt,
future::{ready, Ready},
io,
marker::PhantomData,
path::{Path, PathBuf},
sync::Arc,
};
use tower_http::services::{ServeDir, ServeFile};
use tower_service::Service;
/// Router for single page applications.
///
@ -50,10 +42,9 @@ use tower_service::Service;
/// - `GET /some/other/path` will serve `index.html` since there isn't another
/// route for it
/// - `GET /api/foo` will serve the `api_foo` handler function
pub struct SpaRouter<S = (), B = Body, T = (), F = fn(io::Error) -> Ready<StatusCode>> {
pub struct SpaRouter<S = (), B = Body> {
paths: Arc<Paths>,
handle_error: F,
_marker: PhantomData<fn() -> (S, B, T)>,
_marker: PhantomData<fn() -> (S, B)>,
}
#[derive(Debug)]
@ -63,7 +54,7 @@ struct Paths {
index_file: PathBuf,
}
impl<S, B> SpaRouter<S, B, (), fn(io::Error) -> Ready<StatusCode>> {
impl<S, B> SpaRouter<S, B> {
/// Create a new `SpaRouter`.
///
/// Assets will be served at `GET /{serve_assets_at}` from the directory at `assets_dir`.
@ -80,13 +71,12 @@ impl<S, B> SpaRouter<S, B, (), fn(io::Error) -> Ready<StatusCode>> {
assets_dir: path.to_owned(),
index_file: path.join("index.html"),
}),
handle_error: |_| ready(StatusCode::INTERNAL_SERVER_ERROR),
_marker: PhantomData,
}
}
}
impl<S, B, T, F> SpaRouter<S, B, T, F> {
impl<S, B> SpaRouter<S, B> {
/// Set the path to the index file.
///
/// `path` must be relative to `assets_dir` passed to [`SpaRouter::new`].
@ -114,72 +104,27 @@ impl<S, B, T, F> SpaRouter<S, B, T, F> {
});
self
}
/// Change the function used to handle unknown IO errors.
///
/// `SpaRouter` automatically maps missing files and permission denied to
/// `404 Not Found`. The callback given here will be used for other IO errors.
///
/// See [`axum::error_handling::HandleErrorLayer`] for more details.
///
/// # Example
///
/// ```
/// use std::io;
/// use axum_extra::routing::SpaRouter;
/// use axum::{Router, http::{Method, Uri}};
///
/// let spa = SpaRouter::new("/assets", "dist").handle_error(handle_error);
///
/// async fn handle_error(method: Method, uri: Uri, err: io::Error) -> String {
/// format!("{} {} failed with {}", method, uri, err)
/// }
///
/// let app = Router::new().merge(spa);
/// # let _: Router = app;
/// ```
pub fn handle_error<T2, F2>(self, f: F2) -> SpaRouter<S, B, T2, F2> {
SpaRouter {
paths: self.paths,
handle_error: f,
_marker: PhantomData,
}
}
}
impl<S, B, F, T> From<SpaRouter<S, B, T, F>> for Router<S, B>
impl<S, B> From<SpaRouter<S, B>> for Router<S, B>
where
F: Clone + Send + Sync + 'static,
HandleError<Route<B, io::Error>, F, T>: Service<Request<B>, Error = Infallible>,
<HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Response: IntoResponse + Send,
<HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Future: Send,
B: HttpBody + Send + 'static,
T: 'static,
S: Clone + Send + Sync + 'static,
{
fn from(spa: SpaRouter<S, B, T, F>) -> Router<S, B> {
let assets_service = get_service(ServeDir::new(&spa.paths.assets_dir))
.handle_error(spa.handle_error.clone());
fn from(spa: SpaRouter<S, B>) -> Router<S, B> {
let assets_service = ServeDir::new(&spa.paths.assets_dir);
Router::new()
.nest_service(&spa.paths.assets_path, assets_service)
.fallback_service(
get_service(ServeFile::new(&spa.paths.index_file)).handle_error(spa.handle_error),
)
.fallback_service(ServeFile::new(&spa.paths.index_file))
}
}
impl<B, T, F> fmt::Debug for SpaRouter<B, T, F> {
impl<B, T> fmt::Debug for SpaRouter<B, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let Self {
paths,
handle_error: _,
_marker,
} = self;
let Self { paths, _marker } = self;
f.debug_struct("SpaRouter")
.field("paths", &paths)
.field("handle_error", &format_args!("{}", type_name::<F>()))
.field("request_body_type", &format_args!("{}", type_name::<B>()))
.field(
"extractor_input_type",
@ -189,14 +134,10 @@ impl<B, T, F> fmt::Debug for SpaRouter<B, T, F> {
}
}
impl<B, T, F> Clone for SpaRouter<B, T, F>
where
F: Clone,
{
impl<B, T> Clone for SpaRouter<B, T> {
fn clone(&self) -> Self {
Self {
paths: self.paths.clone(),
handle_error: self.handle_error,
_marker: self._marker,
}
}
@ -206,10 +147,8 @@ where
mod tests {
use super::*;
use crate::test_helpers::*;
use axum::{
http::{Method, Uri},
routing::get,
};
use axum::routing::get;
use http::StatusCode;
#[tokio::test]
async fn basic() {
@ -253,21 +192,6 @@ mod tests {
assert_eq!(res.text().await, "<strong>Hello, World!</strong>\n");
}
// this should just compile
#[allow(dead_code)]
fn setting_error_handler() {
async fn handle_error(method: Method, uri: Uri, err: io::Error) -> (StatusCode, String) {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("{} {} failed. Error: {}", method, uri, err),
)
}
let spa = SpaRouter::new("/assets", "test_files").handle_error(handle_error);
Router::<(), Body>::new().merge(spa);
}
#[allow(dead_code)]
fn works_with_router_with_state() {
let _: Router = Router::new()

View file

@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- None.
- **changed:** Update to tower-http 0.4. axum is still compatible with tower-http 0.3
# 0.6.8 (24. February, 2023)

View file

@ -27,7 +27,38 @@ tower-log = ["tower/log"]
ws = ["tokio", "dep:tokio-tungstenite", "dep:sha1", "dep:base64"]
# Required for intra-doc links to resolve correctly
__private_docs = ["tower/full", "tower-http/full"]
__private_docs = [
"tower/full",
# all tower-http features except (de)?compression-zstd which doesn't
# build on `--target armv5te-unknown-linux-musleabi`
"tower-http/add-extension",
"tower-http/auth",
"tower-http/catch-panic",
"tower-http/compression-br",
"tower-http/compression-deflate",
"tower-http/compression-gzip",
"tower-http/cors",
"tower-http/decompression-br",
"tower-http/decompression-deflate",
"tower-http/decompression-gzip",
"tower-http/follow-redirect",
"tower-http/fs",
"tower-http/limit",
"tower-http/map-request-body",
"tower-http/map-response-body",
"tower-http/metrics",
"tower-http/normalize-path",
"tower-http/propagate-header",
"tower-http/redirect",
"tower-http/request-id",
"tower-http/sensitive-headers",
"tower-http/set-header",
"tower-http/set-status",
"tower-http/timeout",
"tower-http/trace",
"tower-http/util",
"tower-http/validate-request",
]
[dependencies]
async-trait = "0.1.43"
@ -47,7 +78,7 @@ pin-project-lite = "0.2.7"
serde = "1.0"
sync_wrapper = "0.1.1"
tower = { version = "0.4.13", default-features = false, features = ["util"] }
tower-http = { version = "0.3.0", features = ["util", "map-response-body"] }
tower-http = { version = "0.4", features = ["util", "map-response-body"] }
tower-layer = "0.3.2"
tower-service = "0.3"
@ -98,8 +129,38 @@ features = [
]
[dev-dependencies.tower-http]
version = "0.3.4"
features = ["full"]
version = "0.4"
features = [
# all tower-http features except (de)?compression-zstd which doesn't
# build on `--target armv5te-unknown-linux-musleabi`
"add-extension",
"auth",
"catch-panic",
"compression-br",
"compression-deflate",
"compression-gzip",
"cors",
"decompression-br",
"decompression-deflate",
"decompression-gzip",
"follow-redirect",
"fs",
"limit",
"map-request-body",
"map-response-body",
"metrics",
"normalize-path",
"propagate-header",
"redirect",
"request-id",
"sensitive-headers",
"set-header",
"set-status",
"timeout",
"trace",
"util",
"validate-request",
]
[package.metadata.playground]
features = [

View file

@ -17,12 +17,12 @@ use axum::{
routing::get,
Router,
};
use tower_http::auth::RequireAuthorizationLayer;
use tower_http::validate_request::ValidateRequestHeaderLayer;
let app = Router::new().route(
"/foo",
get(|| async {})
.route_layer(RequireAuthorizationLayer::bearer("password"))
.route_layer(ValidateRequestHeaderLayer::bearer("password"))
);
// `GET /foo` with a valid token will receive `200 OK`

View file

@ -17,11 +17,11 @@ use axum::{
routing::get,
Router,
};
use tower_http::auth::RequireAuthorizationLayer;
use tower_http::validate_request::ValidateRequestHeaderLayer;
let app = Router::new()
.route("/foo", get(|| async {}))
.route_layer(RequireAuthorizationLayer::bearer("password"));
.route_layer(ValidateRequestHeaderLayer::bearer("password"));
// `GET /foo` with a valid token will receive `200 OK`
// `GET /foo` with a invalid token will receive `401 Unauthorized`

View file

@ -38,17 +38,10 @@ let app = Router::new()
Ok::<_, Infallible>(res)
})
)
.route(
.route_service(
// GET `/static/Cargo.toml` goes to a service from tower-http
"/static/Cargo.toml",
get_service(ServeFile::new("Cargo.toml"))
// though we must handle any potential errors
.handle_error(|error: io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
})
ServeFile::new("Cargo.toml"),
);
# async {
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();

View file

@ -1290,7 +1290,7 @@ mod tests {
use http::{header::ALLOW, HeaderMap};
use std::time::Duration;
use tower::{timeout::TimeoutLayer, Service, ServiceBuilder, ServiceExt};
use tower_http::{auth::RequireAuthorizationLayer, services::fs::ServeDir};
use tower_http::{services::fs::ServeDir, validate_request::ValidateRequestHeaderLayer};
#[crate::test]
async fn method_not_allowed_by_default() {
@ -1352,7 +1352,7 @@ mod tests {
async fn layer() {
let mut svc = MethodRouter::new()
.get(|| async { std::future::pending::<()>().await })
.layer(RequireAuthorizationLayer::bearer("password"));
.layer(ValidateRequestHeaderLayer::bearer("password"));
// method with route
let (status, _, _) = call(Method::GET, &mut svc).await;
@ -1367,7 +1367,7 @@ mod tests {
async fn route_layer() {
let mut svc = MethodRouter::new()
.get(|| async { std::future::pending::<()>().await })
.route_layer(RequireAuthorizationLayer::bearer("password"));
.route_layer(ValidateRequestHeaderLayer::bearer("password"));
// method with route
let (status, _, _) = call(Method::GET, &mut svc).await;
@ -1385,11 +1385,8 @@ mod tests {
// use the all the things :bomb:
get(ok)
.post(ok)
.route_layer(RequireAuthorizationLayer::bearer("password"))
.merge(
delete_service(ServeDir::new("."))
.handle_error(|_| async { StatusCode::NOT_FOUND }),
)
.route_layer(ValidateRequestHeaderLayer::bearer("password"))
.merge(delete_service(ServeDir::new(".")))
.fallback(|| async { StatusCode::NOT_FOUND })
.put(ok)
.layer(

View file

@ -372,7 +372,7 @@ async fn nesting_and_seeing_the_right_uri_ors_with_multi_segment_uris() {
async fn middleware_that_return_early() {
let private = Router::new()
.route("/", get(|| async {}))
.layer(RequireAuthorizationLayer::bearer("password"));
.layer(ValidateRequestHeaderLayer::bearer("password"));
let public = Router::new().route("/public", get(|| async {}));

View file

@ -20,7 +20,7 @@ use std::{
time::Duration,
};
use tower::{service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder};
use tower_http::{auth::RequireAuthorizationLayer, limit::RequestBodyLimitLayer};
use tower_http::{limit::RequestBodyLimitLayer, validate_request::ValidateRequestHeaderLayer};
use tower_service::Service;
mod fallback;
@ -458,7 +458,7 @@ async fn routing_to_router_panics() {
async fn route_layer() {
let app = Router::new()
.route("/foo", get(|| async {}))
.route_layer(RequireAuthorizationLayer::bearer("password"));
.route_layer(ValidateRequestHeaderLayer::bearer("password"));
let client = TestClient::new(app);

View file

@ -202,15 +202,7 @@ async fn nested_service_sees_stripped_uri() {
#[crate::test]
async fn nest_static_file_server() {
let app = Router::new().nest_service(
"/static",
get_service(ServeDir::new(".")).handle_error(|error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {error}"),
)
}),
);
let app = Router::new().nest_service("/static", ServeDir::new("."));
let client = TestClient::new(app);

View file

@ -9,6 +9,6 @@ axum = { path = "../../axum" }
hyper = "0.14"
tokio = { version = "1.0", features = ["full"] }
tower = "0.4"
tower-http = { version = "0.3", features = ["map-request-body", "util"] }
tower-http = { version = "0.4.0", features = ["map-request-body", "util"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View file

@ -7,4 +7,4 @@ publish = false
[dependencies]
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.3.0", features = ["cors"] }
tower-http = { version = "0.4.0", features = ["cors"] }

View file

@ -8,7 +8,7 @@ publish = false
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
tower = { version = "0.4", features = ["util", "timeout", "load-shed", "limit"] }
tower-http = { version = "0.3.0", features = [
tower-http = { version = "0.4.0", features = [
"add-extension",
"auth",
"compression-full",

View file

@ -25,8 +25,8 @@ use std::{
};
use tower::{BoxError, ServiceBuilder};
use tower_http::{
auth::RequireAuthorizationLayer, compression::CompressionLayer, limit::RequestBodyLimitLayer,
trace::TraceLayer,
compression::CompressionLayer, limit::RequestBodyLimitLayer, trace::TraceLayer,
validate_request::ValidateRequestHeaderLayer,
};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
@ -128,7 +128,7 @@ fn admin_routes() -> Router<SharedState> {
.route("/keys", delete(delete_all_keys))
.route("/key/:key", delete(remove_key))
// Require bearer auth for all admin routes
.layer(RequireAuthorizationLayer::bearer("secret-token"))
.layer(ValidateRequestHeaderLayer::bearer("secret-token"))
}
async fn handle_error(error: BoxError) -> impl IntoResponse {

View file

@ -7,6 +7,6 @@ publish = false
[dependencies]
axum = { path = "../../axum", features = ["multipart"] }
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.3.0", features = ["limit", "trace"] }
tower-http = { version = "0.4.0", features = ["limit", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View file

@ -10,6 +10,6 @@ futures = "0.3"
headers = "0.3"
tokio = { version = "1.0", features = ["full"] }
tokio-stream = "0.1"
tower-http = { version = "0.3.0", features = ["fs", "trace"] }
tower-http = { version = "0.4.0", features = ["fs", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View file

@ -6,9 +6,8 @@
use axum::{
extract::TypedHeader,
http::StatusCode,
response::sse::{Event, Sse},
routing::{get, get_service},
routing::get,
Router,
};
use futures::stream::{self, Stream};
@ -29,15 +28,7 @@ async fn main() {
let assets_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets");
let static_files_service = get_service(
ServeDir::new(assets_dir).append_index_html_on_directories(true),
)
.handle_error(|error: std::io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
});
let static_files_service = ServeDir::new(assets_dir).append_index_html_on_directories(true);
// build our application with a route
let app = Router::new()

View file

@ -9,6 +9,6 @@ axum = { path = "../../axum" }
axum-extra = { path = "../../axum-extra", features = ["spa"] }
tokio = { version = "1.0", features = ["full"] }
tower = { version = "0.4", features = ["util"] }
tower-http = { version = "0.3.0", features = ["fs", "trace"] }
tower-http = { version = "0.4.0", features = ["fs", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View file

@ -8,12 +8,11 @@ use axum::{
body::Body,
handler::HandlerWithoutStateExt,
http::{Request, StatusCode},
response::IntoResponse,
routing::{get, get_service},
routing::get,
Router,
};
use axum_extra::routing::SpaRouter;
use std::{io, net::SocketAddr};
use std::net::SocketAddr;
use tower::ServiceExt;
use tower_http::{
services::{ServeDir, ServeFile},
@ -56,7 +55,7 @@ fn using_serve_dir() -> Router {
// `SpaRouter` is just a convenient wrapper around `ServeDir`
//
// You can use `ServeDir` directly to further customize your setup
let serve_dir = get_service(ServeDir::new("assets")).handle_error(handle_error);
let serve_dir = ServeDir::new("assets");
Router::new()
.route("/foo", get(|| async { "Hi from /foo" }))
@ -69,7 +68,6 @@ fn using_serve_dir_with_assets_fallback() -> Router {
// so with this `GET /assets/doesnt-exist.jpg` will return `index.html`
// rather than a 404
let serve_dir = ServeDir::new("assets").not_found_service(ServeFile::new("assets/index.html"));
let serve_dir = get_service(serve_dir).handle_error(handle_error);
Router::new()
.route("/foo", get(|| async { "Hi from /foo" }))
@ -81,7 +79,6 @@ fn using_serve_dir_only_from_root_via_fallback() -> Router {
// you can also serve the assets directly from the root (not nested under `/assets`)
// by only setting a `ServeDir` as the fallback
let serve_dir = ServeDir::new("assets").not_found_service(ServeFile::new("assets/index.html"));
let serve_dir = get_service(serve_dir).handle_error(handle_error);
Router::new()
.route("/foo", get(|| async { "Hi from /foo" }))
@ -93,13 +90,7 @@ fn using_serve_dir_with_handler_as_service() -> Router {
(StatusCode::NOT_FOUND, "Not found")
}
// you can convert handler function to service
let service = handle_404
.into_service()
.map_err(|err| -> std::io::Error { match err {} });
let serve_dir = ServeDir::new("assets").not_found_service(service);
let serve_dir = get_service(serve_dir).handle_error(handle_error);
let serve_dir = ServeDir::new("assets").not_found_service(handle_404.into_service());
Router::new()
.route("/foo", get(|| async { "Hi from /foo" }))
@ -108,8 +99,8 @@ fn using_serve_dir_with_handler_as_service() -> Router {
fn two_serve_dirs() -> Router {
// you can also have two `ServeDir`s nested at different paths
let serve_dir_from_assets = get_service(ServeDir::new("assets")).handle_error(handle_error);
let serve_dir_from_dist = get_service(ServeDir::new("dist")).handle_error(handle_error);
let serve_dir_from_assets = ServeDir::new("assets");
let serve_dir_from_dist = ServeDir::new("dist");
Router::new()
.nest_service("/assets", serve_dir_from_assets)
@ -123,17 +114,13 @@ fn calling_serve_dir_from_a_handler() -> Router {
Router::new().nest_service(
"/foo",
get(|request: Request<Body>| async {
let service = get_service(ServeDir::new("assets")).handle_error(handle_error);
let service = ServeDir::new("assets");
let result = service.oneshot(request).await;
result
}),
)
}
async fn handle_error(_err: io::Error) -> impl IntoResponse {
(StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong...")
}
async fn serve(app: Router, port: u16) {
let addr = SocketAddr::from(([127, 0, 0, 1], port));
tracing::debug!("listening on {}", addr);

View file

@ -10,7 +10,7 @@ hyper = { version = "0.14", features = ["full"] }
mime = "0.3"
serde_json = "1.0"
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.3.0", features = ["trace"] }
tower-http = { version = "0.4.0", features = ["trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View file

@ -9,7 +9,7 @@ axum = { path = "../../axum" }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }
tower = { version = "0.4", features = ["util", "timeout"] }
tower-http = { version = "0.3.0", features = ["add-extension", "trace"] }
tower-http = { version = "0.4.0", features = ["add-extension", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.0", features = ["serde", "v4"] }

View file

@ -7,6 +7,6 @@ publish = false
[dependencies]
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.3.0", features = ["trace"] }
tower-http = { version = "0.4.0", features = ["trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View file

@ -12,7 +12,7 @@ headers = "0.3"
tokio = { version = "1.0", features = ["full"] }
tokio-tungstenite = "0.18.0"
tower = { version = "0.4", features = ["util"] }
tower-http = { version = "0.3.0", features = ["fs", "trace"] }
tower-http = { version = "0.4.0", features = ["fs", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View file

@ -23,7 +23,7 @@ use tokio_tungstenite::{
};
const N_CLIENTS: usize = 2; //set to desired number
const SERVER: &'static str = "ws://127.0.0.1:3000/ws";
const SERVER: &str = "ws://127.0.0.1:3000/ws";
#[tokio::main]
async fn main() {

View file

@ -21,9 +21,8 @@ use axum::{
ws::{Message, WebSocket, WebSocketUpgrade},
TypedHeader,
},
http::StatusCode,
response::IntoResponse,
routing::{get, get_service},
routing::get,
Router,
};
@ -58,15 +57,7 @@ async fn main() {
// build our application with some routes
let app = Router::new()
.fallback_service(
get_service(ServeDir::new(assets_dir).append_index_html_on_directories(true))
.handle_error(|error: std::io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
}),
)
.fallback_service(ServeDir::new(assets_dir).append_index_html_on_directories(true))
.route("/ws", get(ws_handler))
// logging so we can see whats going on
.layer(
@ -98,7 +89,7 @@ async fn ws_handler(
} else {
String::from("Unknown browser")
};
println!("`{}` at {} connected.", user_agent, addr.to_string());
println!("`{user_agent}` at {addr} connected.");
// finalize the upgrade process by returning upgrade callback.
// we can customize the callback by sending additional info such as address.
ws.on_upgrade(move |socket| handle_socket(socket, addr))
@ -107,7 +98,7 @@ async fn ws_handler(
/// Actual websocket statemachine (one will be spawned per connection)
async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
//send a ping (unsupported by some browsers) just to kick things off and get a response
if let Ok(_) = socket.send(Message::Ping(vec![1, 2, 3])).await {
if socket.send(Message::Ping(vec![1, 2, 3])).await.is_ok() {
println!("Pinged {}...", who);
} else {
println!("Could not send ping {}!", who);
@ -126,7 +117,7 @@ async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
return;
}
} else {
println!("client {} abruptly disconnected", who);
println!("client {who} abruptly disconnected");
return;
}
}
@ -137,11 +128,11 @@ async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
// connecting to server and receiving their greetings.
for i in 1..5 {
if socket
.send(Message::Text(String::from(format!("Hi {} times!", i))))
.send(Message::Text(format!("Hi {i} times!")))
.await
.is_err()
{
println!("client {} abruptly disconnected", who);
println!("client {who} abruptly disconnected");
return;
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
@ -157,7 +148,7 @@ async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
for i in 0..n_msg {
// In case of any websocket error, we exit.
if sender
.send(Message::Text(format!("Server message {} ...", i)))
.send(Message::Text(format!("Server message {i} ...")))
.await
.is_err()
{
@ -167,7 +158,7 @@ async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
tokio::time::sleep(std::time::Duration::from_millis(300)).await;
}
println!("Sending close to {}...", who);
println!("Sending close to {who}...");
if let Err(e) = sender
.send(Message::Close(Some(CloseFrame {
code: axum::extract::ws::close_code::NORMAL,