1
0
Fork 0
mirror of https://github.com/ii64/gouring.git synced 2025-04-01 03:41:44 +02:00

fix(queue): queue out of sync, test

Signed-off-by: MastahSenpai <26342994+ii64@users.noreply.github.com>
This commit is contained in:
MastahSenpai 2021-12-23 00:06:53 +07:00
parent 3661a48501
commit 4742c90656
Signed by untrusted user who does not match committer: Xeffy
GPG key ID: E41C08AD390E7C49
2 changed files with 51 additions and 17 deletions

View file

@ -4,6 +4,7 @@ package queue
// https://github.com/iceber/io_uring-go types.go
import (
"errors"
"fmt"
"runtime"
"sync"
@ -13,6 +14,10 @@ import (
"github.com/ii64/gouring"
)
var (
ErrQueueClosed = errors.New("queue closed")
)
type Queue struct {
ring *gouring.Ring
sq *gouring.SQRing
@ -21,7 +26,10 @@ type Queue struct {
sqeHead uint32
sqeTail uint32
sMx sync.Mutex
cqMx sync.RWMutex // tbd...
clq uint32
}
func New(ring *gouring.Ring) *Queue {
@ -34,9 +42,22 @@ func New(ring *gouring.Ring) *Queue {
ring: ring,
sq: sq,
cq: cq,
clq: 0,
}
}
func (q *Queue) Close() error {
atomic.StoreUint32(&q.clq, 1)
return nil
}
func (q *Queue) precheck() (err error) {
if clq := atomic.LoadUint32(&q.clq); clq == 1 {
err = ErrQueueClosed
return
}
return
}
//
func (q *Queue) _getSQEntry() *gouring.SQEntry {
@ -66,18 +87,17 @@ func (q *Queue) sqFallback(d uint32) {
}
func (q *Queue) sqFlush() uint32 {
ktail := atomic.LoadUint32(q.sq.Tail())
if q.sqeHead == q.sqeTail {
return ktail - atomic.LoadUint32(q.sq.Head())
return atomic.LoadUint32(q.sq.Tail()) - atomic.LoadUint32(q.sq.Head())
}
for toSubmit := q.sqeTail; toSubmit > 0; toSubmit-- {
kmask := *q.sq.RingMask()
*q.sq.Array().Get(ktail & kmask) = q.sqeHead & kmask
ktail := atomic.LoadUint32(q.sq.Tail())
for toSubmit := q.sqeTail - q.sqeHead; toSubmit > 0; toSubmit-- {
*q.sq.Array().Get(ktail & (*q.sq.RingMask())) = q.sqeHead & (*q.sq.RingMask())
ktail++
q.sqeHead++
}
atomic.StoreUint32(q.sq.Tail(), ktail)
return ktail - *q.sq.Head()
}
@ -94,6 +114,8 @@ func (q *Queue) isNeedEnter(flags *uint32) bool {
}
func (q *Queue) Submit() (ret int, err error) {
q.sMx.Lock()
defer q.sMx.Unlock()
submitted := q.sqFlush()
var flags uint32
@ -101,7 +123,7 @@ func (q *Queue) Submit() (ret int, err error) {
return
}
if q.ring.Params().Flags&gouring.IORING_SETUP_IOPOLL > 0 {
if (q.ring.Params().Flags & gouring.IORING_SETUP_IOPOLL) > 0 {
flags |= gouring.IORING_ENTER_GETEVENTS
}
@ -112,8 +134,9 @@ func (q *Queue) Submit() (ret int, err error) {
//
func (q *Queue) cqPeek() (cqe *gouring.CQEntry) {
if atomic.LoadUint32(q.cq.Head()) != atomic.LoadUint32(q.cq.Tail()) {
cqe = q.cq.Get(atomic.LoadUint32(q.cq.Head()) & atomic.LoadUint32(q.cq.RingMask()))
head := atomic.LoadUint32(q.cq.Head())
if head != atomic.LoadUint32(q.cq.Tail()) {
cqe = q.cq.Get(head & atomic.LoadUint32(q.cq.RingMask()))
}
return
}
@ -125,6 +148,9 @@ func (q *Queue) cqAdvance(d uint32) {
}
func (q *Queue) getCQEvent(wait bool) (cqe *gouring.CQEntry, err error) {
if err = q.precheck(); err != nil {
return
}
var tryPeeks int
for {
if cqe = q.cqPeek(); cqe != nil {
@ -155,10 +181,13 @@ func (q *Queue) getCQEvent(wait bool) (cqe *gouring.CQEntry, err error) {
}
func (q *Queue) Run(f func(cqe *gouring.CQEntry)) {
for {
for q.precheck() == nil {
cqe, err := q.getCQEvent(true)
if cqe == nil || err != nil {
fmt.Printf("run error: %+#v\n", err)
if err == ErrQueueClosed {
return
}
continue
}

View file

@ -1,6 +1,7 @@
package queue
import (
"fmt"
"strings"
"sync"
"syscall"
@ -28,10 +29,10 @@ func TestQueue(t *testing.T) {
}()
mkdata := func(i int) []byte {
return []byte("queue pls" + strings.Repeat("!", i) + "\n")
return []byte("queue pls" + strings.Repeat("!", i) + fmt.Sprintf("%d", i) + "\n")
}
N := 5
N := 64 + 64
var wg sync.WaitGroup
btests := [][]byte{}
for i := 0; i < N; i++ {
@ -45,11 +46,15 @@ func TestQueue(t *testing.T) {
for i, b := range btests {
sqe := q.GetSQEntry()
sqe.UserData = uint64(i)
// sqe.Flags = gouring.IOSQE_IO_DRAIN
write(sqe, syscall.Stdout, b)
if (i+1)%2 == 0 {
n, err := q.Submit()
assert.NoError(t, err, "queue submit")
assert.Equal(t, n, 2, "submit count mismatch")
fmt.Printf("submitted %d\n", n)
}
}
n, err := q.Submit()
assert.NoError(t, err, "queue submit")
assert.Equal(t, n, N, "submit count mismatch")
}()
go func() {
q.Run(func(cqe *gouring.CQEntry) {
@ -57,9 +62,9 @@ func TestQueue(t *testing.T) {
assert.Condition(t, func() (success bool) {
return cqe.UserData < uint64(len(btests))
}, "userdata is set with the btest index")
assert.Condition(t, func() (success bool) {
assert.Conditionf(t, func() (success bool) {
return len(btests[cqe.UserData]) == int(cqe.Res)
}, "OP_WRITE result mismatch")
}, "OP_WRITE result mismatch: %+#v", cqe)
})
}()