feat(queue): queue implementation wrapper

Signed-off-by: MastahSenpai <26342994+ii64@users.noreply.github.com>
This commit is contained in:
MastahSenpai 2021-12-22 00:33:38 +07:00
parent 173732c949
commit 66638a93b2
Signed by untrusted user who does not match committer: Xeffy
GPG key ID: E41C08AD390E7C49
2 changed files with 234 additions and 0 deletions

167
queue/queue.go Normal file
View file

@ -0,0 +1,167 @@
package queue
// Modified form
// https://github.com/iceber/io_uring-go types.go
import (
"fmt"
"runtime"
"sync"
"sync/atomic"
"syscall"
"github.com/ii64/gouring"
)
type Queue struct {
ring *gouring.Ring
sq *gouring.SQRing
cq *gouring.CQRing
sqeHead uint32
sqeTail uint32
cqMx sync.RWMutex // tbd...
}
func New(ring *gouring.Ring) *Queue {
if ring == nil {
return nil
}
sq := ring.SQ()
cq := ring.CQ()
return &Queue{
ring: ring,
sq: sq,
cq: cq,
}
}
//
func (q *Queue) _getSQEntry() *gouring.SQEntry {
head := atomic.LoadUint32(q.sq.Head())
next := q.sqeTail + 1
if (next - head) <= atomic.LoadUint32(q.sq.RingEntries()) {
sqe := q.sq.Get(q.sqeTail & atomic.LoadUint32(q.sq.RingMask()))
q.sqeTail = next
sqe.Reset()
return sqe
}
return nil
}
func (q *Queue) GetSQEntry() (sqe *gouring.SQEntry) {
for {
sqe = q._getSQEntry()
if sqe != nil {
return
}
runtime.Gosched()
}
}
func (q *Queue) sqFallback(d uint32) {
q.sqeTail -= d
}
func (q *Queue) sqFlush() uint32 {
ktail := atomic.LoadUint32(q.sq.Tail())
if q.sqeHead == q.sqeTail {
return ktail - atomic.LoadUint32(q.sq.Head())
}
for toSubmit := q.sqeTail; toSubmit > 0; toSubmit-- {
kmask := *q.sq.RingMask()
*q.sq.Array().Get(ktail & kmask) = q.sqeHead & kmask
ktail++
q.sqeHead++
}
atomic.StoreUint32(q.sq.Tail(), ktail)
return ktail - *q.sq.Head()
}
func (q *Queue) isNeedEnter(flags *uint32) bool {
if (q.ring.Params().Features & gouring.IORING_SETUP_SQPOLL) > 0 {
return true
}
if q.sq.IsNeedWakeup() {
*flags |= gouring.IORING_SQ_NEED_WAKEUP
return true
}
return false
}
func (q *Queue) Submit() (ret int, err error) {
submitted := q.sqFlush()
var flags uint32
if !q.isNeedEnter(&flags) || submitted == 0 {
return
}
if q.ring.Params().Flags&gouring.IORING_SETUP_IOPOLL > 0 {
flags |= gouring.IORING_ENTER_GETEVENTS
}
ret, err = q.ring.Enter(uint(submitted), 0, flags, nil)
return
}
//
func (q *Queue) cqPeek() (cqe *gouring.CQEntry) {
if atomic.LoadUint32(q.cq.Head()) != atomic.LoadUint32(q.cq.Tail()) {
cqe = q.cq.Get(atomic.LoadUint32(q.cq.Head()) & atomic.LoadUint32(q.cq.RingMask()))
}
return
}
func (q *Queue) cqAdvance(d uint32) {
if d != 0 {
atomic.AddUint32(q.cq.Head(), d)
}
}
func (q *Queue) getCQEvent(wait bool) (cqe *gouring.CQEntry, err error) {
var tryPeeks int
for {
if cqe = q.cqPeek(); cqe != nil {
q.cqAdvance(1)
return
}
if !wait && !q.sq.IsCQOverflow() {
err = syscall.EAGAIN
return
}
if q.sq.IsCQOverflow() {
_, err = q.ring.Enter(0, 0, gouring.IORING_ENTER_GETEVENTS, nil)
if err != nil {
return
}
continue
}
if tryPeeks++; tryPeeks < 3 {
runtime.Gosched()
continue
}
// implement interrupt
}
}
func (q *Queue) Run(f func(cqe *gouring.CQEntry)) {
for {
cqe, err := q.getCQEvent(true)
if cqe == nil || err != nil {
fmt.Printf("run error: %+#v\n", err)
continue
}
f(cqe)
}
}

67
queue/queue_test.go Normal file
View file

@ -0,0 +1,67 @@
package queue
import (
"strings"
"sync"
"syscall"
"testing"
"github.com/ii64/gouring"
"github.com/stretchr/testify/assert"
)
func write(sqe *gouring.SQEntry, fd int, b []byte) {
sqe.Opcode = gouring.IORING_OP_WRITE
sqe.Fd = int32(fd)
sqe.Len = uint32(len(b))
sqe.SetOffset(0)
// *sqe.Addr() = (uint64)(uintptr(unsafe.Pointer(&b[0])))
sqe.SetAddr(&b[0])
}
func TestQueue(t *testing.T) {
ring, err := gouring.New(256, nil)
assert.NoError(t, err, "create ring")
defer func() {
err := ring.Close()
assert.NoError(t, err, "close ring")
}()
mkdata := func(i int) []byte {
return []byte("queue pls" + strings.Repeat("!", i) + "\n")
}
N := 5
var wg sync.WaitGroup
btests := [][]byte{}
for i := 0; i < N; i++ {
btests = append(btests, mkdata(i))
}
wg.Add(N)
// create new queue
q := New(ring)
go func() {
for i, b := range btests {
sqe := q.GetSQEntry()
sqe.UserData = uint64(i)
write(sqe, syscall.Stdout, b)
}
n, err := q.Submit()
assert.NoError(t, err, "queue submit")
assert.Equal(t, n, N, "submit count mismatch")
}()
go func() {
q.Run(func(cqe *gouring.CQEntry) {
defer wg.Done()
assert.Condition(t, func() (success bool) {
return cqe.UserData < uint64(len(btests))
}, "userdata is set with the btest index")
assert.Condition(t, func() (success bool) {
return len(btests[cqe.UserData]) == int(cqe.Res)
}, "OP_WRITE result mismatch")
})
}()
wg.Wait()
}