From fe8ee1c7bec04a1932a3d2759800639a3d3ef2f2 Mon Sep 17 00:00:00 2001
From: Nugraha <richiisei@gmail.com>
Date: Wed, 26 Apr 2023 00:02:19 +0700
Subject: [PATCH] chore: sync liburing queue

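Sync queue.go with the upstream liburing submission path:

- sq_ring_needs_enter() now takes the number of pending submissions and
  returns false early when there is nothing to submit.
- __io_uring_flush_sq() no longer re-fills the SQ index array on every
  flush; it only publishes the new tail.
- _io_uring_get_cqe() learns has_ts: a wait carrying an EXT_ARG timeout
  stops after one pass through the kernel and reports ETIME if no CQE
  arrived.
- __io_uring_submit() derives IORING_ENTER_GETEVENTS from
  cq_ring_needs_enter() || waitNr != 0 rather than from waitNr and
  IORING_SETUP_IOPOLL alone.
- Drop TestRingQueueConcurrentEnqueue: io_uring_get_sqe() is now
  documented as unsafe for concurrent callers.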
---
 queue.go      | 87 +++++++++++++++++++++++++++++++--------------------
 queue_test.go | 41 ------------------------
 2 files changed, 53 insertions(+), 75 deletions(-)
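
Reviewer note (below the --- marker, so it stays out of the commit message):
with io_uring_get_sqe() now documented as unsafe for concurrent callers, SQE
preparation and submission are expected to stay on a single goroutine. Below
is a minimal sketch of that flow, reusing only identifiers that appear in the
existing tests (GetSqe, PrepNop, UserData.SetUint64, Submit, WaitCqe,
SeenCqe). The helper name is hypothetical, it assumes those helpers are
methods on *IoUring as the tests suggest, the ring constructor is elided, and
WaitCqe's return value is left unchecked as in the tests.

// Single-producer flow after this change: all SQE acquisition and
// submission happen on one goroutine.
// Hypothetical helper; assumes GetSqe/Submit/WaitCqe/SeenCqe are
// methods on *IoUring, as the existing tests suggest.
func submitAndDrainNops(ring *IoUring, n int) error {
	for i := 0; i < n; i++ {
		sqe := ring.GetSqe()
		for sqe == nil {
			// SQ ring is full: flush what we have queued, then retry.
			if _, err := ring.Submit(); err != nil {
				return err
			}
			sqe = ring.GetSqe()
		}
		PrepNop(sqe)
		sqe.UserData.SetUint64(uint64(i))
	}
	// Publish the remaining SQEs to the kernel.
	if _, err := ring.Submit(); err != nil {
		return err
	}
	// Reap one CQE per submitted NOP.
	for i := 0; i < n; i++ {
		var cqe *IoUringCqe
		ring.WaitCqe(&cqe) // return value unchecked, mirroring the tests
		if cqe != nil {
			ring.SeenCqe(cqe)
		}
	}
	return nil
}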

diff --git a/queue.go b/queue.go
index dd7c5a0..7ede4d9 100644
--- a/queue.go
+++ b/queue.go
@@ -14,11 +14,18 @@ const LIBURING_UDATA_TIMEOUT uint64 = ^uint64(0)
  * or if IORING_SQ_NEED_WAKEUP is set, so submit thread must be explicitly
  * awakened. For the latter case, we set the thread wakeup flag.
  */
-func (ring *IoUring) sq_ring_needs_enter(flags *uint32) bool {
+func (ring *IoUring) sq_ring_needs_enter(submitted uint32, flags *uint32) bool {
+	if submitted == 0 {
+		return false
+	}
 	if ring.Flags&IORING_SETUP_SQPOLL == 0 {
 		return true
 	}
 
+	/*
+	 * Ensure the kernel can see the store to the SQ tail before we read
+	 * the flags.
+	 */
 	// FIXME: Extra call - no inline asm.
 	io_uring_smp_mb()
 
@@ -43,6 +50,7 @@ type get_data struct {
 	getFlags uint32
 	sz       int32
 	arg      unsafe.Pointer
+	hasTs    bool
 }
 
 func (ring *IoUring) _io_uring_get_cqe(cqePtr **IoUringCqe, data *get_data) (err error) {
@@ -58,6 +66,11 @@ func (ring *IoUring) _io_uring_get_cqe(cqePtr **IoUringCqe, data *get_data) (err
 			break
 		}
 		if cqe != nil && data.waitNr == 0 && data.submit == 0 {
+			/*
+			 * If we already looped once, we already entered
+			 * the kernel. Since there's nothing to submit or
+			 * wait for, don't keep retrying.
+			 */
 			if looped || !ring.cq_ring_needs_enter() {
 				err = syscall.EAGAIN
 				break
@@ -68,12 +81,19 @@ func (ring *IoUring) _io_uring_get_cqe(cqePtr **IoUringCqe, data *get_data) (err
 			flags = IORING_ENTER_GETEVENTS | data.getFlags
 			needEnter = true
 		}
-		if data.submit > 0 && ring.sq_ring_needs_enter(&flags) {
+		if data.submit > 0 && ring.sq_ring_needs_enter(data.submit, &flags) {
 			needEnter = true
 		}
 		if !needEnter {
 			break
 		}
+		if looped && data.hasTs {
+			arg := (*IoUringGeteventsArg)(data.arg)
+			if cqe == nil && arg.Ts != 0 && err == nil {
+				err = syscall.ETIME
+			}
+			break
+		}
 
 		if ring.IntFlags&INT_FLAG_REG_RING != 0 {
 			flags |= IORING_ENTER_REGISTERED_RING
@@ -159,30 +179,22 @@ done:
  */
 func (ring *IoUring) __io_uring_flush_sq() uint32 {
 	sq := &ring.Sq
-	var mask = *sq._KRingMask()
-	var ktail = *sq._KTail()
-	var toSubmit = sq.SqeTail - sq.SqeHead
+	tail := sq.SqeTail
 
-	if toSubmit < 1 {
-		goto out
+	if sq.SqeHead != tail {
+		sq.SqeHead = tail
+
+		/*
+		 * Ensure kernel sees the SQE updates before the tail update.
+		 */
+		atomic.StoreUint32(sq._KTail(), tail)
+		// Upstream liburing distinguishes the two cases:
+		//   without IORING_SETUP_SQPOLL: IO_URING_WRITE_ONCE(*sq->ktail, tail)
+		//   with IORING_SETUP_SQPOLL:    io_uring_smp_store_release(sq->ktail, tail)
+		// A Go atomic store is at least as strong as both, so a single
+		// call covers either mode.
 	}
 
-	/*
-	 * Fill in sqes that we have queued up, adding them to the kernel ring
-	 */
-	for ; toSubmit > 0; toSubmit-- {
-		*uint32Array_Index(sq.Array, uintptr(ktail&mask)) = sq.SqeHead & mask
-		ktail++
-		sq.SqeHead++
-	}
-
-	/*
-	 * Ensure that the kernel sees the SQE updates before it sees the tail
-	 * update.
-	 */
-	atomic.StoreUint32(sq._KTail(), ktail)
-
-out:
 	/*
 	 * This _may_ look problematic, as we're not supposed to be reading
 	 * SQ->head without acquire semantics. When we're in SQPOLL mode, the
@@ -194,7 +206,7 @@ out:
 	 * we can submit. The point is, we need to be able to deal with this
 	 * situation regardless of any perceived atomicity.
 	 */
-	return ktail - *sq._KHead()
+	return tail - *sq._KHead()
 }
 
 /*
@@ -233,6 +245,9 @@ func (ring *IoUring) io_uring_wait_cqes_new(cqePtr **IoUringCqe, waitNtr uint32,
  * handling between two threads.
  */
 func (ring *IoUring) __io_uring_submit_timeout(waitNr uint32, ts *syscall.Timespec) (ret int, err error) {
+	/*
+	 * If the SQ ring is full, we may need to submit IO first
+	 */
 	sqe := ring.io_uring_get_sqe()
 	if sqe == nil {
 		ret, err = ring.io_uring_submit()
@@ -268,7 +283,7 @@ func (ring *IoUring) io_uring_wait_cqes(cqePtr **IoUringCqe, waitNtr uint32, ts
 	return
 }
 
-func (ring *IoUring) io_uring_submit_and_wait_timeout(cqePtr **IoUringCqe, waitNtr uint32, ts *syscall.Timespec, sigmask *Sigset_t) (err error) {
+func (ring *IoUring) io_uring_submit_and_wait_timeout(cqePtr **IoUringCqe, waitNr uint32, ts *syscall.Timespec, sigmask *Sigset_t) (err error) {
 	var toSubmit int
 	if ts != nil {
 		if ring.Features&IORING_FEAT_EXT_ARG != 0 {
@@ -279,21 +294,22 @@ func (ring *IoUring) io_uring_submit_and_wait_timeout(cqePtr **IoUringCqe, waitN
 			}
 			data := &get_data{
 				submit:   ring.__io_uring_flush_sq(),
-				waitNr:   waitNtr,
+				waitNr:   waitNr,
 				getFlags: IORING_ENTER_EXT_ARG,
 				sz:       int32(unsafe.Sizeof(arg)),
 				arg:      unsafe.Pointer(&arg),
+				hasTs:    ts != nil,
 			}
 			return ring._io_uring_get_cqe(cqePtr, data)
 		}
-		toSubmit, err = ring.__io_uring_submit_timeout(waitNtr, ts)
+		toSubmit, err = ring.__io_uring_submit_timeout(waitNr, ts)
 		if err != nil {
 			return
 		}
 	} else {
 		toSubmit = int(ring.__io_uring_flush_sq())
 	}
-	err = ring.__io_uring_get_cqe(cqePtr, uint32(toSubmit), waitNtr, sigmask)
+	err = ring.__io_uring_get_cqe(cqePtr, uint32(toSubmit), waitNr, sigmask)
 	return
 }
 
@@ -329,9 +345,10 @@ func (ring *IoUring) __io_uring_submit_and_wait(waitNr uint32) (int, error) {
 
 func (ring *IoUring) __io_uring_submit(submitted uint32, waitNr uint32) (ret int, err error) {
 	var flags uint32 = 0
+	var cq_needs_enter = ring.cq_ring_needs_enter() || waitNr != 0
 
-	if ring.sq_ring_needs_enter(&flags) || waitNr != 0 {
-		if waitNr != 0 || ring.Flags&IORING_SETUP_IOPOLL != 0 {
+	if ring.sq_ring_needs_enter(submitted, &flags) || cq_needs_enter {
+		if cq_needs_enter {
 			flags |= IORING_ENTER_GETEVENTS
 		}
 		if ring.IntFlags&INT_FLAG_REG_RING != 0 {
@@ -354,6 +371,8 @@ func (ring *IoUring) io_uring_get_sqe() *IoUringSqe {
  * function multiple times before calling io_uring_submit().
  *
  * Returns a vacant sqe, or NULL if we're full.
+ *
+ * SAFETY: not safe for concurrent use; callers must serialize SQE acquisition.
  */
 func (ring *IoUring) _io_uring_get_sqe() (sqe *IoUringSqe) {
 	sq := &ring.Sq
@@ -397,7 +416,7 @@ func (ring *IoUring) __io_uring_peek_cqe(cqePtr **IoUringCqe, nrAvail *uint32) e
 
 		cqe = nil
 		avail = int(tail - head)
-		if avail < 1 {
+		if avail <= 0 {
 			break
 		}
 
@@ -423,10 +442,10 @@ func (ring *IoUring) __io_uring_peek_cqe(cqePtr **IoUringCqe, nrAvail *uint32) e
 	if nrAvail != nil {
 		*nrAvail = uint32(avail)
 	}
-	if err == 0 {
-		return nil
+	if err < 0 {
+		return syscall.Errno(-err)
 	}
-	return syscall.Errno(-err)
+	return nil
 }
 
 func (ring *IoUring) io_uring_cq_advance(nr uint32) {
diff --git a/queue_test.go b/queue_test.go
index 981111b..b960d97 100644
--- a/queue_test.go
+++ b/queue_test.go
@@ -226,44 +226,3 @@ func TestRingQueueSubmitSingleConsumer(t *testing.T) {
 		})
 	}
 }
-
-func TestRingQueueConcurrentEnqueue(t *testing.T) {
-	const entries = 64
-	h := testNewIoUring(t, entries, 0)
-	defer h.Close()
-
-	var wg sync.WaitGroup
-	var tabChecker sync.Map
-	wg.Add(entries)
-	for i := 0; i < entries; i++ {
-		go func(i int) {
-			defer wg.Done()
-			sqe := h.GetSqe()
-			PrepNop(sqe)
-			sqe.UserData.SetUint64(uint64(i))
-			if _, exist := tabChecker.LoadOrStore(sqe, struct{}{}); exist {
-				panic("enqueue race detect")
-			}
-		}(i)
-	}
-	// Join results before submit
-	wg.Wait()
-
-	// Ring full, this one should be nil.
-	require.Nil(t, h.GetSqe())
-
-	// Submit
-	submitted, err := h.Submit()
-	require.NoError(t, err)
-	println(submitted)
-
-	var cqe *IoUringCqe
-	for i := 0; i < entries; i++ {
-		h.WaitCqe(&cqe)
-		if _, exist := tabChecker.LoadOrStore(cqe, struct{}{}); exist {
-			panic("cqe race detect")
-		}
-		h.SeenCqe(cqe)
-	}
-
-}