diff --git a/queue/queue.go b/queue/queue.go new file mode 100644 index 0000000..79569c6 --- /dev/null +++ b/queue/queue.go @@ -0,0 +1,167 @@ +package queue + +// Modified form +// https://github.com/iceber/io_uring-go types.go + +import ( + "fmt" + "runtime" + "sync" + "sync/atomic" + "syscall" + + "github.com/ii64/gouring" +) + +type Queue struct { + ring *gouring.Ring + sq *gouring.SQRing + cq *gouring.CQRing + + sqeHead uint32 + sqeTail uint32 + + cqMx sync.RWMutex // tbd... +} + +func New(ring *gouring.Ring) *Queue { + if ring == nil { + return nil + } + sq := ring.SQ() + cq := ring.CQ() + return &Queue{ + ring: ring, + sq: sq, + cq: cq, + } +} + +// + +func (q *Queue) _getSQEntry() *gouring.SQEntry { + head := atomic.LoadUint32(q.sq.Head()) + next := q.sqeTail + 1 + if (next - head) <= atomic.LoadUint32(q.sq.RingEntries()) { + sqe := q.sq.Get(q.sqeTail & atomic.LoadUint32(q.sq.RingMask())) + q.sqeTail = next + sqe.Reset() + return sqe + } + return nil +} + +func (q *Queue) GetSQEntry() (sqe *gouring.SQEntry) { + for { + sqe = q._getSQEntry() + if sqe != nil { + return + } + runtime.Gosched() + } +} + +func (q *Queue) sqFallback(d uint32) { + q.sqeTail -= d +} + +func (q *Queue) sqFlush() uint32 { + ktail := atomic.LoadUint32(q.sq.Tail()) + if q.sqeHead == q.sqeTail { + return ktail - atomic.LoadUint32(q.sq.Head()) + } + + for toSubmit := q.sqeTail; toSubmit > 0; toSubmit-- { + kmask := *q.sq.RingMask() + *q.sq.Array().Get(ktail & kmask) = q.sqeHead & kmask + + ktail++ + q.sqeHead++ + } + atomic.StoreUint32(q.sq.Tail(), ktail) + return ktail - *q.sq.Head() +} + +func (q *Queue) isNeedEnter(flags *uint32) bool { + if (q.ring.Params().Features & gouring.IORING_SETUP_SQPOLL) > 0 { + return true + } + if q.sq.IsNeedWakeup() { + *flags |= gouring.IORING_SQ_NEED_WAKEUP + return true + } + return false +} + +func (q *Queue) Submit() (ret int, err error) { + submitted := q.sqFlush() + + var flags uint32 + if !q.isNeedEnter(&flags) || submitted == 0 { + return + } + + if q.ring.Params().Flags&gouring.IORING_SETUP_IOPOLL > 0 { + flags |= gouring.IORING_ENTER_GETEVENTS + } + + ret, err = q.ring.Enter(uint(submitted), 0, flags, nil) + return +} + +// + +func (q *Queue) cqPeek() (cqe *gouring.CQEntry) { + if atomic.LoadUint32(q.cq.Head()) != atomic.LoadUint32(q.cq.Tail()) { + cqe = q.cq.Get(atomic.LoadUint32(q.cq.Head()) & atomic.LoadUint32(q.cq.RingMask())) + } + return +} + +func (q *Queue) cqAdvance(d uint32) { + if d != 0 { + atomic.AddUint32(q.cq.Head(), d) + } +} + +func (q *Queue) getCQEvent(wait bool) (cqe *gouring.CQEntry, err error) { + var tryPeeks int + for { + if cqe = q.cqPeek(); cqe != nil { + q.cqAdvance(1) + return + } + + if !wait && !q.sq.IsCQOverflow() { + err = syscall.EAGAIN + return + } + + if q.sq.IsCQOverflow() { + _, err = q.ring.Enter(0, 0, gouring.IORING_ENTER_GETEVENTS, nil) + if err != nil { + return + } + continue + } + + if tryPeeks++; tryPeeks < 3 { + runtime.Gosched() + continue + } + + // implement interrupt + } +} + +func (q *Queue) Run(f func(cqe *gouring.CQEntry)) { + for { + cqe, err := q.getCQEvent(true) + if cqe == nil || err != nil { + fmt.Printf("run error: %+#v\n", err) + continue + } + + f(cqe) + } +} diff --git a/queue/queue_test.go b/queue/queue_test.go new file mode 100644 index 0000000..14f5be7 --- /dev/null +++ b/queue/queue_test.go @@ -0,0 +1,67 @@ +package queue + +import ( + "strings" + "sync" + "syscall" + "testing" + + "github.com/ii64/gouring" + "github.com/stretchr/testify/assert" +) + +func write(sqe *gouring.SQEntry, fd int, b []byte) { + sqe.Opcode = gouring.IORING_OP_WRITE + sqe.Fd = int32(fd) + sqe.Len = uint32(len(b)) + sqe.SetOffset(0) + // *sqe.Addr() = (uint64)(uintptr(unsafe.Pointer(&b[0]))) + sqe.SetAddr(&b[0]) +} + +func TestQueue(t *testing.T) { + ring, err := gouring.New(256, nil) + assert.NoError(t, err, "create ring") + defer func() { + err := ring.Close() + assert.NoError(t, err, "close ring") + }() + + mkdata := func(i int) []byte { + return []byte("queue pls" + strings.Repeat("!", i) + "\n") + } + + N := 5 + var wg sync.WaitGroup + btests := [][]byte{} + for i := 0; i < N; i++ { + btests = append(btests, mkdata(i)) + } + wg.Add(N) + + // create new queue + q := New(ring) + go func() { + for i, b := range btests { + sqe := q.GetSQEntry() + sqe.UserData = uint64(i) + write(sqe, syscall.Stdout, b) + } + n, err := q.Submit() + assert.NoError(t, err, "queue submit") + assert.Equal(t, n, N, "submit count mismatch") + }() + go func() { + q.Run(func(cqe *gouring.CQEntry) { + defer wg.Done() + assert.Condition(t, func() (success bool) { + return cqe.UserData < uint64(len(btests)) + }, "userdata is set with the btest index") + assert.Condition(t, func() (success bool) { + return len(btests[cqe.UserData]) == int(cqe.Res) + }, "OP_WRITE result mismatch") + }) + }() + + wg.Wait() +}