1. CyclicBarrier 概述
CyclicBarrier 是一个可重用的栅栏并发原语,常常应用于重复进行一组 goroutine 同时执行的场景中。
CyclicBarrier允许一组 goroutine 彼此等待,到达一个共同的执行点。同时,因为它可以被重复使用,所以叫循环栅栏。具体的机制是,大家都在栅栏前等待,等全部都到齐了,就抬起栅栏放行。
1.1 CyclicBarrier 与 WaitGroup
你可能会觉得,CyclicBarrier 和 WaitGroup 的功能有点类似,确实是这样。不过还是有区别的:
- CyclicBarrier 更适合用在“固定数量的 goroutine 等待同一个执行点”的场景中,
- 而且在放行 goroutine 之后,CyclicBarrier 可以重复利用,
- 不像 WaitGroup 重用的时候,必须小心翼翼避免 panic。
处理可重用的多 goroutine 等待同一个执行点的场景的时候,CyclicBarrier 和 WaitGroup 方法调用的对应关系如下:
如果使用 WaitGroup 实现的话,调用比较复杂,不像 CyclicBarrier 那么清爽。更重要的是,如果想重用 WaitGroup,你还要保证,将 WaitGroup 的计数值重置到 n 的时候不会出现并发问题。WaitGroup 更适合用在“一个 goroutine 等待一组 goroutine 到达同一个执行点”的场景中,或者是不需要重用的场景中。
1.2 CyclicBarrier 使用
CyclicBarrier 有两个初始化方法:
- 第一个是 New 方法,它只需要一个参数,来指定循环栅栏参与者的数量;
- 第二个方法是 NewWithAction
- 它额外提供一个函数,可以在每一次到达执行点的时候执行一次
- 执行具体的时间点是在最后一个参与者到达之后,但是其它的参与者还未被放行之前。我们可以利用它,做放行之前的一些共享状态的更新等操作。
1
2
3
|
func New(parties int) CyclicBarrier
func NewWithAction(parties int, barrierAction func() error) CyclicBarrier
|
CyclicBarrier 是一个接口,定义的方法如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
type CyclicBarrier interface {
// 等待所有的参与者到达,如果被ctx.Done()中断,会返回ErrBrokenBarrier
Await(ctx context.Context) error
// 重置循环栅栏到初始化状态。如果当前有等待者,那么它们会返回ErrBrokenBarrier
Reset()
// 返回当前等待者的数量
GetNumberWaiting() int
// 参与者的数量
GetParties() int
// 循环栅栏是否处于中断状态
IsBroken() bool
}
|
循环栅栏的使用也很简单。循环栅栏的参与者只需调用 Await 等待,等所有的参与者都到达后,再执行下一步。当执行下一步的时候,循环栅栏的状态又恢复到初始的状态了,可以迎接下一轮同样多的参与者。下面是一个使用示例: 生产水原子,每生产一个水分子,就会打印出 HHO、HOH、OHH 三种形式的其中一种。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
|
package water
import (
"context"
"github.com/marusama/cyclicbarrier"
"golang.org/x/sync/semaphore"
)
// 定义水分子合成的辅助数据结构
type H2O struct {
semaH *semaphore.Weighted // 氢原子的信号量
semaO *semaphore.Weighted // 氧原子的信号量
b cyclicbarrier.CyclicBarrier // 循环栅栏,用来控制合成
}
func New() *H2O {
return &H2O{
semaH: semaphore.NewWeighted(2), //氢原子需要两个
semaO: semaphore.NewWeighted(1), // 氧原子需要一个
b: cyclicbarrier.New(3), // 需要三个原子才能合成
}
}
func (h2o *H2O) hydrogen(releaseHydrogen func()) {
h2o.semaH.Acquire(context.Background(), 1)
releaseHydrogen() // 输出H
h2o.b.Await(context.Background()) //等待栅栏放行
h2o.semaH.Release(1) // 释放氢原子空槽
}
func (h2o *H2O) oxygen(releaseOxygen func()) {
h2o.semaO.Acquire(context.Background(), 1)
releaseOxygen() // 输出O
h2o.b.Await(context.Background()) //等待栅栏放行
h2o.semaO.Release(1) // 释放氢原子空槽
}
|
下面是对应的单元测试
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
|
package water
import (
"math/rand"
"sort"
"sync"
"testing"
"time"
)
func TestWaterFactory(t *testing.T) {
//用来存放水分子结果的channel
var ch chan string
releaseHydrogen := func() {
ch <- "H"
}
releaseOxygen := func() {
ch <- "O"
}
// 300个原子,300个goroutine,每个goroutine并发的产生一个原子
var N = 100
ch = make(chan string, N*3)
h2o := New()
// 用来等待所有的goroutine完成
var wg sync.WaitGroup
wg.Add(N * 3)
// 200个氢原子goroutine
for i := 0; i < 2*N; i++ {
go func() {
time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
h2o.hydrogen(releaseHydrogen)
wg.Done()
}()
}
// 100个氧原子goroutine
for i := 0; i < N; i++ {
go func() {
time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
h2o.oxygen(releaseOxygen)
wg.Done()
}()
}
//等待所有的goroutine执行完
wg.Wait()
// 结果中肯定是300个原子
if len(ch) != N*3 {
t.Fatalf("expect %d atom but got %d", N*3, len(ch))
}
// 每三个原子一组,分别进行检查。要求这一组原子中必须包含两个氢原子和一个氧原子,这样才能正确组成一个水分子。
var s = make([]string, 3)
for i := 0; i < N; i++ {
s[0] = <-ch
s[1] = <-ch
s[2] = <-ch
sort.Strings(s)
water := s[0] + s[1] + s[2]
if water != "HHO" {
t.Fatalf("expect a water molecule but got %s", water)
}
}
}
|
如果你没有学习 CyclicBarrier,你可能只会想到,用 WaitGroup 来实现这个水分子制造工厂的例子。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
|
type H2O struct {
semaH *semaphore.Weighted
semaO *semaphore.Weighted
wg sync.WaitGroup //将循环栅栏替换成WaitGroup
}
func New() *H2O {
var wg sync.WaitGroup
wg.Add(3)
return &H2O{
semaH: semaphore.NewWeighted(2),
semaO: semaphore.NewWeighted(1),
wg: wg,
}
}
func (h2o *H2O) hydrogen(releaseHydrogen func()) {
h2o.semaH.Acquire(context.Background(), 1)
releaseHydrogen()
// 标记自己已达到,等待其它goroutine到达
h2o.wg.Done()
h2o.wg.Wait()
h2o.semaH.Release(1)
}
func (h2o *H2O) oxygen(releaseOxygen func()) {
h2o.semaO.Acquire(context.Background(), 1)
releaseOxygen()
// 标记自己已达到,等待其它goroutine到达
h2o.wg.Done()
h2o.wg.Wait()
//都到达后重置wg
h2o.wg.Add(3)
h2o.semaO.Release(1)
}
|
使用 WaitGroup 非常复杂,而且,重用和 Done 方法的调用有并发的问题,程序可能 panic,远远没有使用循环栅栏更加简单直接。
1.3 CyclicBarrier 的实现
CyclicBarrier 的数据结构
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
|
// round
type round struct {
count int // count of goroutines for this roundtrip
waitCh chan struct{} // wait channel for this roundtrip
brokeCh chan struct{} // channel for isBroken broadcast
isBroken bool // is barrier broken
}
// cyclicBarrier impl CyclicBarrier intf
type cyclicBarrier struct {
parties int
barrierAction func() error
lock sync.RWMutex
round *round
}
// New initializes a new instance of the CyclicBarrier, specifying the number of parties.
func New(parties int) CyclicBarrier {
if parties <= 0 {
panic("parties must be positive number")
}
return &cyclicBarrier{
parties: parties,
lock: sync.RWMutex{},
round: &round{
waitCh: make(chan struct{}),
brokeCh: make(chan struct{}),
},
}
}
|
Await
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
|
func (b *cyclicBarrier) Await(ctx context.Context) error {
var (
ctxDoneCh <-chan struct{}
)
if ctx != nil {
ctxDoneCh = ctx.Done()
}
// check if context is done
select {
case <-ctxDoneCh:
return ctx.Err()
default:
}
b.lock.Lock()
// check if broken
if b.round.isBroken {
b.lock.Unlock()
return ErrBrokenBarrier
}
// increment count of waiters
b.round.count++
// saving in local variables to prevent race
waitCh := b.round.waitCh
brokeCh := b.round.brokeCh
count := b.round.count
b.lock.Unlock()
if count > b.parties {
panic("CyclicBarrier.Await is called more than count of parties")
}
if count < b.parties {
// wait other parties
select {
case <-waitCh:
return nil
case <-brokeCh:
return ErrBrokenBarrier
case <-ctxDoneCh:
b.breakBarrier(true)
return ctx.Err()
}
} else {
// we are last, run the barrier action and reset the barrier
if b.barrierAction != nil {
err := b.barrierAction()
if err != nil {
b.breakBarrier(true)
return err
}
}
b.reset(true)
return nil
}
}
func (b *cyclicBarrier) reset(safe bool) {
b.lock.Lock()
defer b.lock.Unlock()
if safe {
// broadcast to pass waiting goroutines
close(b.round.waitCh)
} else if b.round.count > 0 {
b.breakBarrier(false)
}
// create new round
b.round = &round{
waitCh: make(chan struct{}),
brokeCh: make(chan struct{}),
}
}
func (b *cyclicBarrier) breakBarrier(needLock bool) {
if needLock {
b.lock.Lock()
defer b.lock.Unlock()
}
if !b.round.isBroken {
b.round.isBroken = true
// broadcast
close(b.round.brokeCh)
}
}
|
参考
本文内容摘录自:
- 极客专栏-鸟叔的 Go 并发编程实战