A high-performance MPSC queue between wait-free and lock-free

This article implements a high-performance MPSC (multi-producer, single-consumer) queue whose progress guarantee sits between wait-free and lock-free. The idea: every producer swaps its node onto the head of an intrusive linked list with a single atomic operation; the producer that observes an empty previous head takes over the consumer role, reverses the accumulated nodes back into FIFO order, and drains them, so no mutex is needed on the push path.
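For orientation, here is a minimal usage sketch. The handler passed to NewExecutionQueue is hypothetical, and the sketch assumes the ExecutionQueue code from the full listing below is in the same package (which already imports fmt):

// Hypothetical usage of the queue defined in the listing below.
func example() {
	ex := NewExecutionQueue(func(v interface{}) {
		fmt.Println("consumed:", v) // consumer callback, runs once per item
	})
	// Any goroutine may push concurrently; the producer that finds the queue
	// empty also drives consumption.
	for i := 0; i < 5; i++ {
		go ex.AddTaskNode(i)
	}
}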

Timing comparison against a channel.
Test setup: 10,000 producers, each pushing 90 items; the figures below are the time for all producers to finish enqueueing. (Test1 in the listing actually runs 20 such batches per run and reports the average per batch; a testing.B equivalent is sketched after the table.)

Run    MPSC queue      channel
1      115.454185ms    965.9469ms
2      97.00209ms      992.595ms
3      109.776455ms    989.837ms
4      90.19232ms      1.3611857s
5      79.12408ms      1.2725355s
6      68.650755ms     1.3981461s
7      77.155775ms     1.2469043s
8      103.899735ms    1.4227174s
9      72.15339ms      1.3351006s
10     129.974485ms    1.0454128s
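The numbers above come from Test1 (queue) and Test2 (channel) in the listing below, run via go run. If you prefer the standard benchmark harness, a rough equivalent might look like the following sketch. It assumes the listing below is saved as package main and that this file sits next to it as queue_bench_test.go (a file name chosen here purely for illustration):

package main

import (
	"sync"
	"testing"
)

// Rough testing.B equivalent of Test1: 10,000 producers, 90 pushes each,
// measuring only the enqueue side. Assumes the ExecutionQueue code from the
// article is in the same package.
func BenchmarkExecutionQueue(b *testing.B) {
	for n := 0; n < b.N; n++ {
		ex := NewExecutionQueue(func(interface{}) {})
		var wg sync.WaitGroup
		for i := 0; i < 10000; i++ {
			wg.Add(1)
			go func(i int) {
				defer wg.Done()
				for j := 0; j < 90; j++ {
					ex.AddTaskNode(i*100 + j)
				}
			}(i)
		}
		wg.Wait()
	}
}

Run with: go test -bench ExecutionQueue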

package main

import (
	"fmt"
	"runtime"
	"sync"
	"sync/atomic"
	"time"
	"unsafe"
)

// TaskNode is an intrusive singly linked list node holding one queued item.
type TaskNode struct {
	Data interface{} `json:"data"`
	Next *TaskNode   `json:"Next"`
}

// UNCONNECTED marks a node that has been swapped onto the head but whose
// Next pointer has not yet been linked by its producer.
var UNCONNECTED *TaskNode = new(TaskNode)

// NewExecutionQueue creates a queue whose items are handed to _func by the
// single active consumer.
func NewExecutionQueue(_func func(interface{})) *ExecutionQueue {
	return &ExecutionQueue{
		Head:          nil,
		_execute_func: _func,
		locker:        sync.Mutex{},
		pool: &sync.Pool{New: func() interface{} {
			return new(TaskNode)
		}},
	}
}

type ExecutionQueue struct {
	Head          *TaskNode         `json:"Head"` // most recently pushed node (LIFO chain)
	_execute_func func(interface{}) `json:"-"`    // consumer callback
	locker        sync.Mutex        `json:"-"`    // currently unused
	pool          *sync.Pool        `json:"-"`    // recycles TaskNodes to reduce allocations
}

// AddTaskNode pushes data onto the queue. The push itself is a single atomic
// swap; the producer that finds the queue empty also takes on the consumer role.
func (ex *ExecutionQueue) AddTaskNode(data interface{}) {
	node := ex.pool.Get().(*TaskNode)
	node.Data = data
	node.Next = UNCONNECTED // not linked yet; the consumer must wait for this to change

	// Swap ourselves onto the head of the LIFO chain.
	preHead := atomic.SwapPointer((*unsafe.Pointer)(unsafe.Pointer(&ex.Head)), unsafe.Pointer(node))

	if preHead != nil {
		// Another node was already queued: link to it and let the current consumer drain us.
		node.Next = (*TaskNode)(preHead)
		return
	}

	// The queue was empty, so this producer becomes the consumer.
	node.Next = nil
	// Execute the single task inline to avoid a goroutine switch when the queue is idle.
	ex._execute_func(node.Data)
	if !ex.moreTasks(node) {
		return
	}
	go ex.executeTasks(node)
}

// moreTasks checks whether anything was pushed after oldNode. If not, it
// resets Head to nil and returns false. Otherwise it reverses the LIFO chain
// that has accumulated behind Head back into FIFO order, links it onto
// oldNode.Next, and returns true.
func (ex *ExecutionQueue) moreTasks(oldNode *TaskNode) bool {
	// If Head is still oldNode, nothing new arrived; clear it and stop consuming.
	if atomic.CompareAndSwapPointer((*unsafe.Pointer)(unsafe.Pointer(&ex.Head)), unsafe.Pointer(oldNode), nil) {
		return false
	}

	newHead := (*TaskNode)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&ex.Head))))
	var tail *TaskNode
	p := newHead
	for {
		// Wait until the producer that pushed p has linked its Next pointer.
		// Note: this plain read races with the producer's plain write in
		// AddTaskNode; strictly speaking Next should be accessed atomically.
		for p.Next == UNCONNECTED {
			runtime.Gosched()
		}
		// Reverse the link: p.Next now points at the already-reversed prefix.
		savedNext := p.Next
		p.Next = tail
		tail = p
		p = savedNext

		if p == oldNode {
			// Reached the already-consumed node: hang the FIFO chain off it.
			oldNode.Next = tail
			return true
		}
	}
}

// executeTasks drains the FIFO chain starting at taskNode. taskNode itself has
// already been executed; it is recycled and its successor is run next. When the
// chain runs out, moreTasks either picks up newly pushed nodes or ends consumption.
func (ex *ExecutionQueue) executeTasks(taskNode *TaskNode) {
	for {
		tmp := taskNode

		taskNode = taskNode.Next
		tmp.Next = nil
		ex.pool.Put(tmp) // recycle the already-executed node
		ex._execute_func(taskNode.Data)

		if taskNode.Next == nil && !ex.moreTasks(taskNode) {
			return
		}
	}
}

var count int64 = 0

// print is the consumer callback used by Test1: do a trivial computation and
// count how many items have been consumed.
func print(data interface{}) {
	_ = data.(int) * data.(int)
	atomic.AddInt64(&count, 1)
}
// Test1 benchmarks the MPSC queue: 20 batches of 10,000 producers, each
// pushing 90 items; the reported time is the average per batch.
func Test1() {
	var singalexit = sync.WaitGroup{}
	ex := NewExecutionQueue(print)
	start := time.Now()
	for k := 0; k < 20; k++ {
		for i := 0; i < 10000; i++ {
			singalexit.Add(1)
			go func(i int, singalexit *sync.WaitGroup) {
				defer singalexit.Done()
				for j := 0; j < 90; j++ {
					ex.AddTaskNode(i*100 + j)
				}
			}(i, &singalexit)
		}
	}
	singalexit.Wait()
	elapsed := time.Since(start)
	fmt.Println("producers finished enqueueing, average per batch:", elapsed/20)
	time.Sleep(2 * time.Second) // give the consumer time to drain the queue
	fmt.Println(atomic.LoadInt64(&count))
}

// Test2 is the channel baseline: a buffered channel drained by a single
// consumer goroutine, with 10,000 producers each sending 90 items.
func Test2() {
	var singalexit sync.WaitGroup
	data := make(chan int, 2000)
	var count1 int64 = 0
	go func() {
		// single consumer draining the channel
		for {
			<-data
			atomic.AddInt64(&count1, 1)
		}
	}()
	start := time.Now()
	for i := 0; i < 10000; i++ {
		singalexit.Add(1)
		go func(i int) {
			defer singalexit.Done()
			for j := 0; j < 90; j++ {
				data <- (i*100 + j)
			}
		}(i)
	}
	singalexit.Wait()
	elapsed := time.Since(start)
	fmt.Println("producers finished enqueueing:", elapsed)
	time.Sleep(2 * time.Second) // give the consumer time to drain the channel
	fmt.Println(atomic.LoadInt64(&count1))
}

func main() {
	for i := 0; i < 10; i++ {
		count = 0
		Test1()
	}
	for i := 0; i < 10; i++ {
		Test2()
	}

	// Test2()
}