diff --git a/queue/queue.go b/queue/queue.go index 79569c6..0d33c6c 100644 --- a/queue/queue.go +++ b/queue/queue.go @@ -4,6 +4,7 @@ package queue // https://github.com/iceber/io_uring-go types.go import ( + "errors" "fmt" "runtime" "sync" @@ -13,6 +14,10 @@ import ( "github.com/ii64/gouring" ) +var ( + ErrQueueClosed = errors.New("queue closed") +) + type Queue struct { ring *gouring.Ring sq *gouring.SQRing @@ -21,7 +26,10 @@ type Queue struct { sqeHead uint32 sqeTail uint32 + sMx sync.Mutex cqMx sync.RWMutex // tbd... + + clq uint32 } func New(ring *gouring.Ring) *Queue { @@ -34,9 +42,22 @@ func New(ring *gouring.Ring) *Queue { ring: ring, sq: sq, cq: cq, + clq: 0, } } +func (q *Queue) Close() error { + atomic.StoreUint32(&q.clq, 1) + return nil +} +func (q *Queue) precheck() (err error) { + if clq := atomic.LoadUint32(&q.clq); clq == 1 { + err = ErrQueueClosed + return + } + return +} + // func (q *Queue) _getSQEntry() *gouring.SQEntry { @@ -66,18 +87,17 @@ func (q *Queue) sqFallback(d uint32) { } func (q *Queue) sqFlush() uint32 { - ktail := atomic.LoadUint32(q.sq.Tail()) if q.sqeHead == q.sqeTail { - return ktail - atomic.LoadUint32(q.sq.Head()) + return atomic.LoadUint32(q.sq.Tail()) - atomic.LoadUint32(q.sq.Head()) } - for toSubmit := q.sqeTail; toSubmit > 0; toSubmit-- { - kmask := *q.sq.RingMask() - *q.sq.Array().Get(ktail & kmask) = q.sqeHead & kmask - + ktail := atomic.LoadUint32(q.sq.Tail()) + for toSubmit := q.sqeTail - q.sqeHead; toSubmit > 0; toSubmit-- { + *q.sq.Array().Get(ktail & (*q.sq.RingMask())) = q.sqeHead & (*q.sq.RingMask()) ktail++ q.sqeHead++ } + atomic.StoreUint32(q.sq.Tail(), ktail) return ktail - *q.sq.Head() } @@ -94,6 +114,8 @@ func (q *Queue) isNeedEnter(flags *uint32) bool { } func (q *Queue) Submit() (ret int, err error) { + q.sMx.Lock() + defer q.sMx.Unlock() submitted := q.sqFlush() var flags uint32 @@ -101,7 +123,7 @@ func (q *Queue) Submit() (ret int, err error) { return } - if q.ring.Params().Flags&gouring.IORING_SETUP_IOPOLL > 0 { + if (q.ring.Params().Flags & gouring.IORING_SETUP_IOPOLL) > 0 { flags |= gouring.IORING_ENTER_GETEVENTS } @@ -112,8 +134,9 @@ func (q *Queue) Submit() (ret int, err error) { // func (q *Queue) cqPeek() (cqe *gouring.CQEntry) { - if atomic.LoadUint32(q.cq.Head()) != atomic.LoadUint32(q.cq.Tail()) { - cqe = q.cq.Get(atomic.LoadUint32(q.cq.Head()) & atomic.LoadUint32(q.cq.RingMask())) + head := atomic.LoadUint32(q.cq.Head()) + if head != atomic.LoadUint32(q.cq.Tail()) { + cqe = q.cq.Get(head & atomic.LoadUint32(q.cq.RingMask())) } return } @@ -125,6 +148,9 @@ func (q *Queue) cqAdvance(d uint32) { } func (q *Queue) getCQEvent(wait bool) (cqe *gouring.CQEntry, err error) { + if err = q.precheck(); err != nil { + return + } var tryPeeks int for { if cqe = q.cqPeek(); cqe != nil { @@ -155,10 +181,13 @@ func (q *Queue) getCQEvent(wait bool) (cqe *gouring.CQEntry, err error) { } func (q *Queue) Run(f func(cqe *gouring.CQEntry)) { - for { + for q.precheck() == nil { cqe, err := q.getCQEvent(true) if cqe == nil || err != nil { fmt.Printf("run error: %+#v\n", err) + if err == ErrQueueClosed { + return + } continue } diff --git a/queue/queue_test.go b/queue/queue_test.go index 14f5be7..34bce24 100644 --- a/queue/queue_test.go +++ b/queue/queue_test.go @@ -1,6 +1,7 @@ package queue import ( + "fmt" "strings" "sync" "syscall" @@ -28,10 +29,10 @@ func TestQueue(t *testing.T) { }() mkdata := func(i int) []byte { - return []byte("queue pls" + strings.Repeat("!", i) + "\n") + return []byte("queue pls" + strings.Repeat("!", i) + fmt.Sprintf("%d", i) + "\n") } - N := 5 + N := 64 + 64 var wg sync.WaitGroup btests := [][]byte{} for i := 0; i < N; i++ { @@ -45,11 +46,15 @@ func TestQueue(t *testing.T) { for i, b := range btests { sqe := q.GetSQEntry() sqe.UserData = uint64(i) + // sqe.Flags = gouring.IOSQE_IO_DRAIN write(sqe, syscall.Stdout, b) + if (i+1)%2 == 0 { + n, err := q.Submit() + assert.NoError(t, err, "queue submit") + assert.Equal(t, n, 2, "submit count mismatch") + fmt.Printf("submitted %d\n", n) + } } - n, err := q.Submit() - assert.NoError(t, err, "queue submit") - assert.Equal(t, n, N, "submit count mismatch") }() go func() { q.Run(func(cqe *gouring.CQEntry) { @@ -57,9 +62,9 @@ func TestQueue(t *testing.T) { assert.Condition(t, func() (success bool) { return cqe.UserData < uint64(len(btests)) }, "userdata is set with the btest index") - assert.Condition(t, func() (success bool) { + assert.Conditionf(t, func() (success bool) { return len(btests[cqe.UserData]) == int(cqe.Res) - }, "OP_WRITE result mismatch") + }, "OP_WRITE result mismatch: %+#v", cqe) }) }()