From 64b72d47c8984ca9f42d43ddef268b1a9e2eef17 Mon Sep 17 00:00:00 2001
From: MastahSenpai <26342994+ii64@users.noreply.github.com>
Date: Sun, 26 Dec 2021 18:04:46 +0700
Subject: [PATCH] feat(queue): lock-free, and thread-safe

Signed-off-by: MastahSenpai <26342994+ii64@users.noreply.github.com>
---
 queue/queue.go           | 28 +++++++-------
 queue/queue_lock.go      | 54 +++++++++++++++++++++++++++
 queue/queue_lock_test.go | 79 ++++++++++++++++++++++++++++++++++++++++
 queue/queue_test.go      | 70 ++---------------------------------
 4 files changed, 151 insertions(+), 80 deletions(-)
 create mode 100644 queue/queue_lock.go
 create mode 100644 queue/queue_lock_test.go

diff --git a/queue/queue.go b/queue/queue.go
index 5a05420..08eda32 100644
--- a/queue/queue.go
+++ b/queue/queue.go
@@ -6,7 +6,6 @@ package queue
 import (
 	"errors"
 	"runtime"
-	"sync"
 	"sync/atomic"
 	"syscall"
 
@@ -17,6 +16,8 @@ var (
 	ErrQueueClosed = errors.New("queue closed")
 )
 
+type QueueCQEHandler func(cqe *gouring.CQEntry) (err error)
+
 type Queue struct {
 	ring *gouring.Ring
 	sq   *gouring.SQRing
@@ -25,9 +26,6 @@ type Queue struct {
 	sqeHead uint32
 	sqeTail uint32
 
-	sMx  sync.Mutex
-	cqMx sync.Mutex
-
 	err error
 
 	clq uint32
@@ -115,8 +113,8 @@ func (q *Queue) isNeedEnter(flags *uint32) bool {
 }
 
 func (q *Queue) Submit() (ret int, err error) {
-	q.sMx.Lock()
-	defer q.sMx.Unlock()
+	// q.sMx.Lock()
+	// defer q.sMx.Unlock()
 	submitted := q.sqFlush()
 
 	var flags uint32
@@ -148,9 +146,9 @@ func (q *Queue) cqAdvance(d uint32) {
 	}
 }
 
-func (q *Queue) GetCQEvent(wait bool) (cqe *gouring.CQEntry, err error) {
-	q.cqMx.Lock()
-	defer q.cqMx.Unlock()
+func (q *Queue) GetCQEntry(wait bool) (cqe *gouring.CQEntry, err error) {
+	// q.cqMx.Lock()
+	// defer q.cqMx.Unlock()
 	if err = q.precheck(); err != nil {
 		return
 	}
@@ -187,17 +185,21 @@ func (q *Queue) Err() error {
 	return q.err
 }
 
-func (q *Queue) Run(f func(cqe *gouring.CQEntry)) {
+func (q *Queue) Run(wait bool, f QueueCQEHandler) (err error) {
 	for q.precheck() == nil {
-		cqe, err := q.GetCQEvent(true)
+		cqe, err := q.GetCQEntry(wait)
 		if cqe == nil || err != nil {
 			q.err = err
 			if err == ErrQueueClosed {
-				return
+				return err
 			}
 			continue
 		}
 
-		f(cqe)
+		err = f(cqe)
+		if err != nil {
+			return err
+		}
 	}
+	return nil
 }
diff --git a/queue/queue_lock.go b/queue/queue_lock.go
new file mode 100644
index 0000000..2551f7b
--- /dev/null
+++ b/queue/queue_lock.go
@@ -0,0 +1,54 @@
+package queue
+
+import (
+	"sync"
+
+	"github.com/ii64/gouring"
+)
+
+type QueueLocks struct {
+	*Queue
+
+	sMx  sync.Mutex
+	cqMx sync.Mutex
+}
+
+func NewWithLocks(ring *gouring.Ring) *QueueLocks {
+	q := &QueueLocks{
+		Queue: New(ring),
+	}
+	return q
+}
+
+func (q *QueueLocks) Submit() (ret int, err error) {
+	q.sMx.Lock()
+	defer q.sMx.Unlock()
+	return q.Queue.Submit()
+}
+
+//
+
+func (q *QueueLocks) GetCQEntry(wait bool) (cqe *gouring.CQEntry, err error) {
+	q.cqMx.Lock()
+	defer q.cqMx.Unlock()
+	return q.Queue.GetCQEntry(wait)
+}
+
+func (q *QueueLocks) Run(wait bool, f QueueCQEHandler) (err error) {
+	for q.precheck() == nil {
+		cqe, err := q.GetCQEntry(wait)
+		if cqe == nil || err != nil {
+			q.err = err
+			if err == ErrQueueClosed {
+				return err
+			}
+			continue
+		}
+
+		err = f(cqe)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
diff --git a/queue/queue_lock_test.go b/queue/queue_lock_test.go
new file mode 100644
index 0000000..879cb18
--- /dev/null
+++ b/queue/queue_lock_test.go
@@ -0,0 +1,79 @@
+package queue
+
+import (
+	"fmt"
+	"sync"
+	"syscall"
+	"testing"
+
+	"github.com/ii64/gouring"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestQueueMultiConsumer(t *testing.T) {
+	ring, err := gouring.New(256, nil)
+	assert.NoError(t, err, "create ring")
+	defer func() {
+		err := ring.Close()
+		assert.NoError(t, err, "close ring")
+	}()
+
+	N := 64 + 64
+	var wg sync.WaitGroup
+	btests := [][]byte{}
+	for i := 0; i < N; i++ {
+		btests = append(btests, mkdata(i))
+	}
+	wg.Add(N)
+
+	// create new queue
+	q := NewWithLocks(ring)
+	go func() {
+		for i, b := range btests {
+			sqe := q.GetSQEntry()
+			sqe.UserData = uint64(i)
+			// sqe.Flags = gouring.IOSQE_IO_DRAIN
+			write(sqe, syscall.Stdout, b)
+			if (i+1)%2 == 0 {
+				n, err := q.Submit()
+				assert.NoError(t, err, "queue submit")
+				assert.Equal(t, n, 2, "submit count mismatch")
+				fmt.Printf("submitted %d\n", n)
+			}
+		}
+	}()
+
+	// for i := 0; i < consumerNum; i++ {
+	// 	go func(i int) {
+	// 		fmt.Printf("wrk(%d): start.\n", i)
+	// 		q.Run(func(cqe *gouring.CQEntry) {
+	// 			if q.Err() != nil {
+	// 				assert.NoError(t, q.Err(), "run cqe poller")
+	// 				return
+	// 			}
+	// 			defer wg.Wait()
+	// 			fmt.Printf("wrk(%d): %+#v\n", i, cqe)
+	// 		})
+	// 	}(i)
+	// }
+
+	consumerNum := 20
+	for i := 0; i < consumerNum; i++ {
+		go func(i int) {
+			q.Run(true, func(cqe *gouring.CQEntry) (err error) {
+				defer wg.Done()
+				fmt.Printf("wrk(%d): cqe: %+#v\n", i, cqe)
+				assert.Condition(t, func() (success bool) {
+					return cqe.UserData < uint64(len(btests))
+				}, "userdata is set with the btest index")
+				assert.Conditionf(t, func() (success bool) {
+					return len(btests[cqe.UserData]) == int(cqe.Res)
+				}, "OP_WRITE result mismatch: %+#v", cqe)
+
+				return nil
+			})
+		}(i)
+	}
+
+	wg.Wait()
+}
diff --git a/queue/queue_test.go b/queue/queue_test.go
index 73ff415..c2eced3 100644
--- a/queue/queue_test.go
+++ b/queue/queue_test.go
@@ -57,7 +57,7 @@ func TestQueue(t *testing.T) {
 		}
 	}()
 	go func() {
-		q.Run(func(cqe *gouring.CQEntry) {
+		q.Run(true, func(cqe *gouring.CQEntry) (err error) {
 			defer wg.Done()
 			fmt.Printf("cqe: %+#v\n", cqe)
 			assert.Condition(t, func() (success bool) {
@@ -66,74 +66,10 @@ func TestQueue(t *testing.T) {
 			assert.Conditionf(t, func() (success bool) {
 				return len(btests[cqe.UserData]) == int(cqe.Res)
 			}, "OP_WRITE result mismatch: %+#v", cqe)
+
+			return nil
 		})
 	}()
 
 	wg.Wait()
 }
-
-func TestQueueMultiConsumer(t *testing.T) {
-	ring, err := gouring.New(256, nil)
-	assert.NoError(t, err, "create ring")
-	defer func() {
-		err := ring.Close()
-		assert.NoError(t, err, "close ring")
-	}()
-
-	N := 64 + 64
-	var wg sync.WaitGroup
-	btests := [][]byte{}
-	for i := 0; i < N; i++ {
-		btests = append(btests, mkdata(i))
-	}
-	wg.Add(N)
-
-	// create new queue
-	q := New(ring)
-	go func() {
-		for i, b := range btests {
-			sqe := q.GetSQEntry()
-			sqe.UserData = uint64(i)
-			// sqe.Flags = gouring.IOSQE_IO_DRAIN
-			write(sqe, syscall.Stdout, b)
-			if (i+1)%2 == 0 {
-				n, err := q.Submit()
-				assert.NoError(t, err, "queue submit")
-				assert.Equal(t, n, 2, "submit count mismatch")
-				fmt.Printf("submitted %d\n", n)
-			}
-		}
-	}()
-
-	// for i := 0; i < consumerNum; i++ {
-	// 	go func(i int) {
-	// 		fmt.Printf("wrk(%d): start.\n", i)
-	// 		q.Run(func(cqe *gouring.CQEntry) {
-	// 			if q.Err() != nil {
-	// 				assert.NoError(t, q.Err(), "run cqe poller")
-	// 				return
-	// 			}
-	// 			defer wg.Wait()
-	// 			fmt.Printf("wrk(%d): %+#v\n", i, cqe)
-	// 		})
-	// 	}(i)
-	// }
-
-	consumerNum := 20
-	for i := 0; i < consumerNum; i++ {
-		go func(i int) {
-			q.Run(func(cqe *gouring.CQEntry) {
-				defer wg.Done()
-				fmt.Printf("wrk(%d): cqe: %+#v\n", i, cqe)
-				assert.Condition(t, func() (success bool) {
-					return cqe.UserData < uint64(len(btests))
-				}, "userdata is set with the btest index")
-				assert.Conditionf(t, func() (success bool) {
-					return len(btests[cqe.UserData]) == int(cqe.Res)
-				}, "OP_WRITE result mismatch: %+#v", cqe)
-			})
-		}(i)
-	}
-
-	wg.Wait()
-}