From 64b72d47c8984ca9f42d43ddef268b1a9e2eef17 Mon Sep 17 00:00:00 2001 From: MastahSenpai <26342994+ii64@users.noreply.github.com> Date: Sun, 26 Dec 2021 18:04:46 +0700 Subject: [PATCH] feat(queue): lock-free, and thread-safe Signed-off-by: MastahSenpai <26342994+ii64@users.noreply.github.com> --- queue/queue.go | 28 +++++++------- queue/queue_lock.go | 54 +++++++++++++++++++++++++++ queue/queue_lock_test.go | 79 ++++++++++++++++++++++++++++++++++++++++ queue/queue_test.go | 70 ++--------------------------------- 4 files changed, 151 insertions(+), 80 deletions(-) create mode 100644 queue/queue_lock.go create mode 100644 queue/queue_lock_test.go diff --git a/queue/queue.go b/queue/queue.go index 5a05420..08eda32 100644 --- a/queue/queue.go +++ b/queue/queue.go @@ -6,7 +6,6 @@ package queue import ( "errors" "runtime" - "sync" "sync/atomic" "syscall" @@ -17,6 +16,8 @@ var ( ErrQueueClosed = errors.New("queue closed") ) +type QueueCQEHandler func(cqe *gouring.CQEntry) (err error) + type Queue struct { ring *gouring.Ring sq *gouring.SQRing @@ -25,9 +26,6 @@ type Queue struct { sqeHead uint32 sqeTail uint32 - sMx sync.Mutex - cqMx sync.Mutex - err error clq uint32 @@ -115,8 +113,8 @@ func (q *Queue) isNeedEnter(flags *uint32) bool { } func (q *Queue) Submit() (ret int, err error) { - q.sMx.Lock() - defer q.sMx.Unlock() + // q.sMx.Lock() + // defer q.sMx.Unlock() submitted := q.sqFlush() var flags uint32 @@ -148,9 +146,9 @@ func (q *Queue) cqAdvance(d uint32) { } } -func (q *Queue) GetCQEvent(wait bool) (cqe *gouring.CQEntry, err error) { - q.cqMx.Lock() - defer q.cqMx.Unlock() +func (q *Queue) GetCQEntry(wait bool) (cqe *gouring.CQEntry, err error) { + // q.cqMx.Lock() + // defer q.cqMx.Unlock() if err = q.precheck(); err != nil { return } @@ -187,17 +185,21 @@ func (q *Queue) Err() error { return q.err } -func (q *Queue) Run(f func(cqe *gouring.CQEntry)) { +func (q *Queue) Run(wait bool, f QueueCQEHandler) (err error) { for q.precheck() == nil { - cqe, err := q.GetCQEvent(true) + cqe, err := q.GetCQEntry(wait) if cqe == nil || err != nil { q.err = err if err == ErrQueueClosed { - return + return err } continue } - f(cqe) + err = f(cqe) + if err != nil { + return err + } } + return nil } diff --git a/queue/queue_lock.go b/queue/queue_lock.go new file mode 100644 index 0000000..2551f7b --- /dev/null +++ b/queue/queue_lock.go @@ -0,0 +1,54 @@ +package queue + +import ( + "sync" + + "github.com/ii64/gouring" +) + +type QueueLocks struct { + *Queue + + sMx sync.Mutex + cqMx sync.Mutex +} + +func NewWithLocks(ring *gouring.Ring) *QueueLocks { + q := &QueueLocks{ + Queue: New(ring), + } + return q +} + +func (q *QueueLocks) Submit() (ret int, err error) { + q.sMx.Lock() + defer q.sMx.Unlock() + return q.Queue.Submit() +} + +// + +func (q *QueueLocks) GetCQEntry(wait bool) (cqe *gouring.CQEntry, err error) { + q.cqMx.Lock() + defer q.cqMx.Unlock() + return q.Queue.GetCQEntry(wait) +} + +func (q *QueueLocks) Run(wait bool, f QueueCQEHandler) (err error) { + for q.precheck() == nil { + cqe, err := q.GetCQEntry(wait) + if cqe == nil || err != nil { + q.err = err + if err == ErrQueueClosed { + return err + } + continue + } + + err = f(cqe) + if err != nil { + return err + } + } + return nil +} diff --git a/queue/queue_lock_test.go b/queue/queue_lock_test.go new file mode 100644 index 0000000..879cb18 --- /dev/null +++ b/queue/queue_lock_test.go @@ -0,0 +1,79 @@ +package queue + +import ( + "fmt" + "sync" + "syscall" + "testing" + + "github.com/ii64/gouring" + "github.com/stretchr/testify/assert" +) + +func TestQueueMultiConsumer(t *testing.T) { + ring, err := gouring.New(256, nil) + assert.NoError(t, err, "create ring") + defer func() { + err := ring.Close() + assert.NoError(t, err, "close ring") + }() + + N := 64 + 64 + var wg sync.WaitGroup + btests := [][]byte{} + for i := 0; i < N; i++ { + btests = append(btests, mkdata(i)) + } + wg.Add(N) + + // create new queue + q := NewWithLocks(ring) + go func() { + for i, b := range btests { + sqe := q.GetSQEntry() + sqe.UserData = uint64(i) + // sqe.Flags = gouring.IOSQE_IO_DRAIN + write(sqe, syscall.Stdout, b) + if (i+1)%2 == 0 { + n, err := q.Submit() + assert.NoError(t, err, "queue submit") + assert.Equal(t, n, 2, "submit count mismatch") + fmt.Printf("submitted %d\n", n) + } + } + }() + + // for i := 0; i < consumerNum; i++ { + // go func(i int) { + // fmt.Printf("wrk(%d): start.\n", i) + // q.Run(func(cqe *gouring.CQEntry) { + // if q.Err() != nil { + // assert.NoError(t, q.Err(), "run cqe poller") + // return + // } + // defer wg.Wait() + // fmt.Printf("wrk(%d): %+#v\n", i, cqe) + // }) + // }(i) + // } + + consumerNum := 20 + for i := 0; i < consumerNum; i++ { + go func(i int) { + q.Run(true, func(cqe *gouring.CQEntry) (err error) { + defer wg.Done() + fmt.Printf("wrk(%d): cqe: %+#v\n", i, cqe) + assert.Condition(t, func() (success bool) { + return cqe.UserData < uint64(len(btests)) + }, "userdata is set with the btest index") + assert.Conditionf(t, func() (success bool) { + return len(btests[cqe.UserData]) == int(cqe.Res) + }, "OP_WRITE result mismatch: %+#v", cqe) + + return nil + }) + }(i) + } + + wg.Wait() +} diff --git a/queue/queue_test.go b/queue/queue_test.go index 73ff415..c2eced3 100644 --- a/queue/queue_test.go +++ b/queue/queue_test.go @@ -57,7 +57,7 @@ func TestQueue(t *testing.T) { } }() go func() { - q.Run(func(cqe *gouring.CQEntry) { + q.Run(true, func(cqe *gouring.CQEntry) (err error) { defer wg.Done() fmt.Printf("cqe: %+#v\n", cqe) assert.Condition(t, func() (success bool) { @@ -66,74 +66,10 @@ func TestQueue(t *testing.T) { assert.Conditionf(t, func() (success bool) { return len(btests[cqe.UserData]) == int(cqe.Res) }, "OP_WRITE result mismatch: %+#v", cqe) + + return nil }) }() wg.Wait() } - -func TestQueueMultiConsumer(t *testing.T) { - ring, err := gouring.New(256, nil) - assert.NoError(t, err, "create ring") - defer func() { - err := ring.Close() - assert.NoError(t, err, "close ring") - }() - - N := 64 + 64 - var wg sync.WaitGroup - btests := [][]byte{} - for i := 0; i < N; i++ { - btests = append(btests, mkdata(i)) - } - wg.Add(N) - - // create new queue - q := New(ring) - go func() { - for i, b := range btests { - sqe := q.GetSQEntry() - sqe.UserData = uint64(i) - // sqe.Flags = gouring.IOSQE_IO_DRAIN - write(sqe, syscall.Stdout, b) - if (i+1)%2 == 0 { - n, err := q.Submit() - assert.NoError(t, err, "queue submit") - assert.Equal(t, n, 2, "submit count mismatch") - fmt.Printf("submitted %d\n", n) - } - } - }() - - // for i := 0; i < consumerNum; i++ { - // go func(i int) { - // fmt.Printf("wrk(%d): start.\n", i) - // q.Run(func(cqe *gouring.CQEntry) { - // if q.Err() != nil { - // assert.NoError(t, q.Err(), "run cqe poller") - // return - // } - // defer wg.Wait() - // fmt.Printf("wrk(%d): %+#v\n", i, cqe) - // }) - // }(i) - // } - - consumerNum := 20 - for i := 0; i < consumerNum; i++ { - go func(i int) { - q.Run(func(cqe *gouring.CQEntry) { - defer wg.Done() - fmt.Printf("wrk(%d): cqe: %+#v\n", i, cqe) - assert.Condition(t, func() (success bool) { - return cqe.UserData < uint64(len(btests)) - }, "userdata is set with the btest index") - assert.Conditionf(t, func() (success bool) { - return len(btests[cqe.UserData]) == int(cqe.Res) - }, "OP_WRITE result mismatch: %+#v", cqe) - }) - }(i) - } - - wg.Wait() -}