1
0
Fork 0
mirror of https://github.com/ii64/gouring.git synced 2025-04-26 05:42:48 +02:00

fix(queue): race on queue

Signed-off-by: MastahSenpai <26342994+ii64@users.noreply.github.com>
This commit is contained in:
MastahSenpai 2021-12-24 12:53:42 +07:00
parent 914abe7c45
commit 2633fb50f9
Signed by untrusted user who does not match committer: Xeffy
GPG key ID: E41C08AD390E7C49
2 changed files with 86 additions and 11 deletions

View file

@ -26,7 +26,9 @@ type Queue struct {
sqeTail uint32
sMx sync.Mutex
cqMx sync.RWMutex // tbd...
cqMx sync.Mutex
err error
clq uint32
}
@ -49,12 +51,12 @@ func (q *Queue) Close() error {
atomic.StoreUint32(&q.clq, 1)
return nil
}
func (q *Queue) precheck() (err error) {
func (q *Queue) precheck() error {
if clq := atomic.LoadUint32(&q.clq); clq == 1 {
err = ErrQueueClosed
return
q.err = ErrQueueClosed
return q.err
}
return
return nil
}
//
@ -142,11 +144,13 @@ func (q *Queue) cqPeek() (cqe *gouring.CQEntry) {
func (q *Queue) cqAdvance(d uint32) {
if d != 0 {
atomic.AddUint32(q.cq.Head(), d)
atomic.AddUint32(q.cq.Head(), d) // mark readed
}
}
func (q *Queue) GetCQEvent(wait bool) (cqe *gouring.CQEntry, err error) {
q.cqMx.Lock()
defer q.cqMx.Unlock()
if err = q.precheck(); err != nil {
return
}
@ -179,11 +183,15 @@ func (q *Queue) GetCQEvent(wait bool) (cqe *gouring.CQEntry, err error) {
}
}
func (q *Queue) Err() error {
return q.err
}
func (q *Queue) Run(f func(cqe *gouring.CQEntry)) {
for q.precheck() == nil {
cqe, err := q.GetCQEvent(true)
if cqe == nil || err != nil {
// fmt.Printf("run error: %+#v\n", err)
q.err = err
if err == ErrQueueClosed {
return
}

View file

@ -20,6 +20,10 @@ func write(sqe *gouring.SQEntry, fd int, b []byte) {
sqe.SetAddr(&b[0])
}
func mkdata(i int) []byte {
return []byte("queue pls" + strings.Repeat("!", i) + fmt.Sprintf("%d", i) + "\n")
}
func TestQueue(t *testing.T) {
ring, err := gouring.New(256, nil)
assert.NoError(t, err, "create ring")
@ -28,10 +32,6 @@ func TestQueue(t *testing.T) {
assert.NoError(t, err, "close ring")
}()
mkdata := func(i int) []byte {
return []byte("queue pls" + strings.Repeat("!", i) + fmt.Sprintf("%d", i) + "\n")
}
N := 64 + 64
var wg sync.WaitGroup
btests := [][]byte{}
@ -59,6 +59,7 @@ func TestQueue(t *testing.T) {
go func() {
q.Run(func(cqe *gouring.CQEntry) {
defer wg.Done()
fmt.Printf("cqe: %+#v\n", cqe)
assert.Condition(t, func() (success bool) {
return cqe.UserData < uint64(len(btests))
}, "userdata is set with the btest index")
@ -70,3 +71,69 @@ func TestQueue(t *testing.T) {
wg.Wait()
}
func TestQueueMultiConsumer(t *testing.T) {
ring, err := gouring.New(256, nil)
assert.NoError(t, err, "create ring")
defer func() {
err := ring.Close()
assert.NoError(t, err, "close ring")
}()
N := 64 + 64
var wg sync.WaitGroup
btests := [][]byte{}
for i := 0; i < N; i++ {
btests = append(btests, mkdata(i))
}
wg.Add(N)
// create new queue
q := New(ring)
go func() {
for i, b := range btests {
sqe := q.GetSQEntry()
sqe.UserData = uint64(i)
// sqe.Flags = gouring.IOSQE_IO_DRAIN
write(sqe, syscall.Stdout, b)
if (i+1)%2 == 0 {
n, err := q.Submit()
assert.NoError(t, err, "queue submit")
assert.Equal(t, n, 2, "submit count mismatch")
fmt.Printf("submitted %d\n", n)
}
}
}()
// for i := 0; i < consumerNum; i++ {
// go func(i int) {
// fmt.Printf("wrk(%d): start.\n", i)
// q.Run(func(cqe *gouring.CQEntry) {
// if q.Err() != nil {
// assert.NoError(t, q.Err(), "run cqe poller")
// return
// }
// defer wg.Wait()
// fmt.Printf("wrk(%d): %+#v\n", i, cqe)
// })
// }(i)
// }
consumerNum := 20
for i := 0; i < consumerNum; i++ {
go func(i int) {
q.Run(func(cqe *gouring.CQEntry) {
defer wg.Done()
fmt.Printf("wrk(%d): cqe: %+#v\n", i, cqe)
assert.Condition(t, func() (success bool) {
return cqe.UserData < uint64(len(btests))
}, "userdata is set with the btest index")
assert.Conditionf(t, func() (success bool) {
return len(btests[cqe.UserData]) == int(cqe.Res)
}, "OP_WRITE result mismatch: %+#v", cqe)
})
}(i)
}
wg.Wait()
}