mirror of
https://github.com/tokio-rs/axum.git
synced 2024-11-21 22:56:46 +01:00
Remove duplication in serving with and without graceful shutdown (#2803)
This commit is contained in:
parent
4b48f308c3
commit
5d8541de6e
3 changed files with 67 additions and 79 deletions
|
@ -966,7 +966,7 @@ async fn logging_rejections() {
|
||||||
rejection_type: String,
|
rejection_type: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
let events = capture_tracing::<RejectionEvent, _, _>(|| async {
|
let events = capture_tracing::<RejectionEvent, _>(|| async {
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/extension", get(|_: Extension<Infallible>| async {}))
|
.route("/extension", get(|_: Extension<Infallible>| async {}))
|
||||||
.route("/string", post(|_: String| async {}));
|
.route("/string", post(|_: String| async {}));
|
||||||
|
@ -987,6 +987,7 @@ async fn logging_rejections() {
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
);
|
);
|
||||||
})
|
})
|
||||||
|
.with_filter("axum::rejection=trace")
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
|
|
@ -213,61 +213,8 @@ where
|
||||||
type IntoFuture = private::ServeFuture;
|
type IntoFuture = private::ServeFuture;
|
||||||
|
|
||||||
fn into_future(self) -> Self::IntoFuture {
|
fn into_future(self) -> Self::IntoFuture {
|
||||||
private::ServeFuture(Box::pin(async move {
|
self.with_graceful_shutdown(std::future::pending())
|
||||||
let Self {
|
.into_future()
|
||||||
tcp_listener,
|
|
||||||
mut make_service,
|
|
||||||
tcp_nodelay,
|
|
||||||
_marker: _,
|
|
||||||
} = self;
|
|
||||||
|
|
||||||
loop {
|
|
||||||
let (tcp_stream, remote_addr) = match tcp_accept(&tcp_listener).await {
|
|
||||||
Some(conn) => conn,
|
|
||||||
None => continue,
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(nodelay) = tcp_nodelay {
|
|
||||||
if let Err(err) = tcp_stream.set_nodelay(nodelay) {
|
|
||||||
trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let tcp_stream = TokioIo::new(tcp_stream);
|
|
||||||
|
|
||||||
poll_fn(|cx| make_service.poll_ready(cx))
|
|
||||||
.await
|
|
||||||
.unwrap_or_else(|err| match err {});
|
|
||||||
|
|
||||||
let tower_service = make_service
|
|
||||||
.call(IncomingStream {
|
|
||||||
tcp_stream: &tcp_stream,
|
|
||||||
remote_addr,
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.unwrap_or_else(|err| match err {})
|
|
||||||
.map_request(|req: Request<Incoming>| req.map(Body::new));
|
|
||||||
|
|
||||||
let hyper_service = TowerToHyperService::new(tower_service);
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
match Builder::new(TokioExecutor::new())
|
|
||||||
// upgrades needed for websockets
|
|
||||||
.serve_connection_with_upgrades(tcp_stream, hyper_service)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(()) => {}
|
|
||||||
Err(_err) => {
|
|
||||||
// This error only appears when the client doesn't send a request and
|
|
||||||
// terminate the connection.
|
|
||||||
//
|
|
||||||
// If client sends one request then terminate connection whenever, it doesn't
|
|
||||||
// appear.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,14 @@
|
||||||
use crate::util::AxumMutex;
|
use crate::util::AxumMutex;
|
||||||
use std::{future::Future, io, sync::Arc};
|
use std::{
|
||||||
|
future::{Future, IntoFuture},
|
||||||
|
io,
|
||||||
|
marker::PhantomData,
|
||||||
|
pin::Pin,
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
|
||||||
use serde::{de::DeserializeOwned, Deserialize};
|
use serde::{de::DeserializeOwned, Deserialize};
|
||||||
|
use tracing::instrument::WithSubscriber;
|
||||||
use tracing_subscriber::prelude::*;
|
use tracing_subscriber::prelude::*;
|
||||||
use tracing_subscriber::{filter::Targets, fmt::MakeWriter};
|
use tracing_subscriber::{filter::Targets, fmt::MakeWriter};
|
||||||
|
|
||||||
|
@ -14,36 +21,69 @@ pub(crate) struct TracingEvent<T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run an async closure and capture the tracing output it produces.
|
/// Run an async closure and capture the tracing output it produces.
|
||||||
pub(crate) async fn capture_tracing<T, F, Fut>(f: F) -> Vec<TracingEvent<T>>
|
pub(crate) fn capture_tracing<T, F>(f: F) -> CaptureTracing<T, F>
|
||||||
where
|
where
|
||||||
F: Fn() -> Fut,
|
|
||||||
Fut: Future,
|
|
||||||
T: DeserializeOwned,
|
T: DeserializeOwned,
|
||||||
{
|
{
|
||||||
let (make_writer, handle) = TestMakeWriter::new();
|
CaptureTracing {
|
||||||
|
f,
|
||||||
|
filter: None,
|
||||||
|
_phantom: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let subscriber = tracing_subscriber::registry().with(
|
pub(crate) struct CaptureTracing<T, F> {
|
||||||
tracing_subscriber::fmt::layer()
|
f: F,
|
||||||
.with_writer(make_writer)
|
filter: Option<Targets>,
|
||||||
.with_target(true)
|
_phantom: PhantomData<fn() -> T>,
|
||||||
.without_time()
|
}
|
||||||
.with_ansi(false)
|
|
||||||
.json()
|
|
||||||
.flatten_event(false)
|
|
||||||
.with_filter("axum=trace".parse::<Targets>().unwrap()),
|
|
||||||
);
|
|
||||||
|
|
||||||
let guard = tracing::subscriber::set_default(subscriber);
|
impl<T, F> CaptureTracing<T, F> {
|
||||||
|
pub(crate) fn with_filter(mut self, filter_string: &str) -> Self {
|
||||||
|
self.filter = Some(filter_string.parse().unwrap());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
f().await;
|
impl<T, F, Fut> IntoFuture for CaptureTracing<T, F>
|
||||||
|
where
|
||||||
|
F: Fn() -> Fut + Send + Sync + 'static,
|
||||||
|
Fut: Future + Send,
|
||||||
|
T: DeserializeOwned,
|
||||||
|
{
|
||||||
|
type Output = Vec<TracingEvent<T>>;
|
||||||
|
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
|
||||||
|
|
||||||
drop(guard);
|
fn into_future(self) -> Self::IntoFuture {
|
||||||
|
let Self { f, filter, .. } = self;
|
||||||
|
Box::pin(async move {
|
||||||
|
let (make_writer, handle) = TestMakeWriter::new();
|
||||||
|
|
||||||
handle
|
let filter = filter.unwrap_or_else(|| "axum=trace".parse().unwrap());
|
||||||
.take()
|
let subscriber = tracing_subscriber::registry().with(
|
||||||
.lines()
|
tracing_subscriber::fmt::layer()
|
||||||
.map(|line| serde_json::from_str(line).unwrap())
|
.with_writer(make_writer)
|
||||||
.collect()
|
.with_target(true)
|
||||||
|
.without_time()
|
||||||
|
.with_ansi(false)
|
||||||
|
.json()
|
||||||
|
.flatten_event(false)
|
||||||
|
.with_filter(filter),
|
||||||
|
);
|
||||||
|
|
||||||
|
let guard = tracing::subscriber::set_default(subscriber);
|
||||||
|
|
||||||
|
f().with_current_subscriber().await;
|
||||||
|
|
||||||
|
drop(guard);
|
||||||
|
|
||||||
|
handle
|
||||||
|
.take()
|
||||||
|
.lines()
|
||||||
|
.map(|line| serde_json::from_str(line).unwrap())
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TestMakeWriter {
|
struct TestMakeWriter {
|
||||||
|
|
Loading…
Reference in a new issue