mirror of
https://github.com/tokio-rs/axum.git
synced 2024-11-21 22:56:46 +01:00
Remove duplication in serving with and without graceful shutdown (#2803)
This commit is contained in:
parent
4b48f308c3
commit
5d8541de6e
3 changed files with 67 additions and 79 deletions
|
@ -966,7 +966,7 @@ async fn logging_rejections() {
|
|||
rejection_type: String,
|
||||
}
|
||||
|
||||
let events = capture_tracing::<RejectionEvent, _, _>(|| async {
|
||||
let events = capture_tracing::<RejectionEvent, _>(|| async {
|
||||
let app = Router::new()
|
||||
.route("/extension", get(|_: Extension<Infallible>| async {}))
|
||||
.route("/string", post(|_: String| async {}));
|
||||
|
@ -987,6 +987,7 @@ async fn logging_rejections() {
|
|||
StatusCode::BAD_REQUEST,
|
||||
);
|
||||
})
|
||||
.with_filter("axum::rejection=trace")
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
|
|
|
@ -213,61 +213,8 @@ where
|
|||
type IntoFuture = private::ServeFuture;
|
||||
|
||||
fn into_future(self) -> Self::IntoFuture {
|
||||
private::ServeFuture(Box::pin(async move {
|
||||
let Self {
|
||||
tcp_listener,
|
||||
mut make_service,
|
||||
tcp_nodelay,
|
||||
_marker: _,
|
||||
} = self;
|
||||
|
||||
loop {
|
||||
let (tcp_stream, remote_addr) = match tcp_accept(&tcp_listener).await {
|
||||
Some(conn) => conn,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
if let Some(nodelay) = tcp_nodelay {
|
||||
if let Err(err) = tcp_stream.set_nodelay(nodelay) {
|
||||
trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
|
||||
}
|
||||
}
|
||||
|
||||
let tcp_stream = TokioIo::new(tcp_stream);
|
||||
|
||||
poll_fn(|cx| make_service.poll_ready(cx))
|
||||
.await
|
||||
.unwrap_or_else(|err| match err {});
|
||||
|
||||
let tower_service = make_service
|
||||
.call(IncomingStream {
|
||||
tcp_stream: &tcp_stream,
|
||||
remote_addr,
|
||||
})
|
||||
.await
|
||||
.unwrap_or_else(|err| match err {})
|
||||
.map_request(|req: Request<Incoming>| req.map(Body::new));
|
||||
|
||||
let hyper_service = TowerToHyperService::new(tower_service);
|
||||
|
||||
tokio::spawn(async move {
|
||||
match Builder::new(TokioExecutor::new())
|
||||
// upgrades needed for websockets
|
||||
.serve_connection_with_upgrades(tcp_stream, hyper_service)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {}
|
||||
Err(_err) => {
|
||||
// This error only appears when the client doesn't send a request and
|
||||
// terminate the connection.
|
||||
//
|
||||
// If client sends one request then terminate connection whenever, it doesn't
|
||||
// appear.
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}))
|
||||
self.with_graceful_shutdown(std::future::pending())
|
||||
.into_future()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,14 @@
|
|||
use crate::util::AxumMutex;
|
||||
use std::{future::Future, io, sync::Arc};
|
||||
use std::{
|
||||
future::{Future, IntoFuture},
|
||||
io,
|
||||
marker::PhantomData,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use serde::{de::DeserializeOwned, Deserialize};
|
||||
use tracing::instrument::WithSubscriber;
|
||||
use tracing_subscriber::prelude::*;
|
||||
use tracing_subscriber::{filter::Targets, fmt::MakeWriter};
|
||||
|
||||
|
@ -14,36 +21,69 @@ pub(crate) struct TracingEvent<T> {
|
|||
}
|
||||
|
||||
/// Run an async closure and capture the tracing output it produces.
|
||||
pub(crate) async fn capture_tracing<T, F, Fut>(f: F) -> Vec<TracingEvent<T>>
|
||||
pub(crate) fn capture_tracing<T, F>(f: F) -> CaptureTracing<T, F>
|
||||
where
|
||||
F: Fn() -> Fut,
|
||||
Fut: Future,
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
let (make_writer, handle) = TestMakeWriter::new();
|
||||
CaptureTracing {
|
||||
f,
|
||||
filter: None,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
let subscriber = tracing_subscriber::registry().with(
|
||||
tracing_subscriber::fmt::layer()
|
||||
.with_writer(make_writer)
|
||||
.with_target(true)
|
||||
.without_time()
|
||||
.with_ansi(false)
|
||||
.json()
|
||||
.flatten_event(false)
|
||||
.with_filter("axum=trace".parse::<Targets>().unwrap()),
|
||||
);
|
||||
pub(crate) struct CaptureTracing<T, F> {
|
||||
f: F,
|
||||
filter: Option<Targets>,
|
||||
_phantom: PhantomData<fn() -> T>,
|
||||
}
|
||||
|
||||
let guard = tracing::subscriber::set_default(subscriber);
|
||||
impl<T, F> CaptureTracing<T, F> {
|
||||
pub(crate) fn with_filter(mut self, filter_string: &str) -> Self {
|
||||
self.filter = Some(filter_string.parse().unwrap());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
f().await;
|
||||
impl<T, F, Fut> IntoFuture for CaptureTracing<T, F>
|
||||
where
|
||||
F: Fn() -> Fut + Send + Sync + 'static,
|
||||
Fut: Future + Send,
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
type Output = Vec<TracingEvent<T>>;
|
||||
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
|
||||
|
||||
drop(guard);
|
||||
fn into_future(self) -> Self::IntoFuture {
|
||||
let Self { f, filter, .. } = self;
|
||||
Box::pin(async move {
|
||||
let (make_writer, handle) = TestMakeWriter::new();
|
||||
|
||||
handle
|
||||
.take()
|
||||
.lines()
|
||||
.map(|line| serde_json::from_str(line).unwrap())
|
||||
.collect()
|
||||
let filter = filter.unwrap_or_else(|| "axum=trace".parse().unwrap());
|
||||
let subscriber = tracing_subscriber::registry().with(
|
||||
tracing_subscriber::fmt::layer()
|
||||
.with_writer(make_writer)
|
||||
.with_target(true)
|
||||
.without_time()
|
||||
.with_ansi(false)
|
||||
.json()
|
||||
.flatten_event(false)
|
||||
.with_filter(filter),
|
||||
);
|
||||
|
||||
let guard = tracing::subscriber::set_default(subscriber);
|
||||
|
||||
f().with_current_subscriber().await;
|
||||
|
||||
drop(guard);
|
||||
|
||||
handle
|
||||
.take()
|
||||
.lines()
|
||||
.map(|line| serde_json::from_str(line).unwrap())
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct TestMakeWriter {
|
||||
|
|
Loading…
Reference in a new issue