目录

go CyclicBarrier 循环栅栏

1. CyclicBarrier 概述

CyclicBarrier 是一个可重用的栅栏并发原语,常常应用于重复进行一组 goroutine 同时执行的场景中。

CyclicBarrier允许一组 goroutine 彼此等待,到达一个共同的执行点。同时,因为它可以被重复使用,所以叫循环栅栏。具体的机制是,大家都在栅栏前等待,等全部都到齐了,就抬起栅栏放行。

1.1 CyclicBarrier 与 WaitGroup

你可能会觉得,CyclicBarrier 和 WaitGroup 的功能有点类似,确实是这样。不过还是有区别的:

  1. CyclicBarrier 更适合用在“固定数量的 goroutine 等待同一个执行点”的场景中,
  2. 而且在放行 goroutine 之后,CyclicBarrier 可以重复利用,
  3. 不像 WaitGroup 重用的时候,必须小心翼翼避免 panic。

处理可重用的多 goroutine 等待同一个执行点的场景的时候,CyclicBarrier 和 WaitGroup 方法调用的对应关系如下:

/images/go/sync/CyclicBarrier.jpg

如果使用 WaitGroup 实现的话,调用比较复杂,不像 CyclicBarrier 那么清爽。更重要的是,如果想重用 WaitGroup,你还要保证,将 WaitGroup 的计数值重置到 n 的时候不会出现并发问题。WaitGroup 更适合用在“一个 goroutine 等待一组 goroutine 到达同一个执行点”的场景中,或者是不需要重用的场景中。

1.2 CyclicBarrier 使用

CyclicBarrier 有两个初始化方法:

  1. 第一个是 New 方法,它只需要一个参数,来指定循环栅栏参与者的数量;
  2. 第二个方法是 NewWithAction
    • 它额外提供一个函数,可以在每一次到达执行点的时候执行一次
    • 执行具体的时间点是在最后一个参与者到达之后,但是其它的参与者还未被放行之前。我们可以利用它,做放行之前的一些共享状态的更新等操作。
1
2
3

func New(parties int) CyclicBarrier
func NewWithAction(parties int, barrierAction func() error) CyclicBarrier

CyclicBarrier 是一个接口,定义的方法如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17

type CyclicBarrier interface {
    // 等待所有的参与者到达,如果被ctx.Done()中断,会返回ErrBrokenBarrier
    Await(ctx context.Context) error

    // 重置循环栅栏到初始化状态。如果当前有等待者,那么它们会返回ErrBrokenBarrier
    Reset()

    // 返回当前等待者的数量
    GetNumberWaiting() int

    // 参与者的数量
    GetParties() int

    // 循环栅栏是否处于中断状态
    IsBroken() bool
}

循环栅栏的使用也很简单。循环栅栏的参与者只需调用 Await 等待,等所有的参与者都到达后,再执行下一步。当执行下一步的时候,循环栅栏的状态又恢复到初始的状态了,可以迎接下一轮同样多的参与者。下面是一个使用示例: 生产水原子,每生产一个水分子,就会打印出 HHO、HOH、OHH 三种形式的其中一种。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38

package water
import (
  "context"
  "github.com/marusama/cyclicbarrier"
  "golang.org/x/sync/semaphore"
)
// 定义水分子合成的辅助数据结构
type H2O struct {
  semaH *semaphore.Weighted // 氢原子的信号量
  semaO *semaphore.Weighted // 氧原子的信号量
  b     cyclicbarrier.CyclicBarrier // 循环栅栏,用来控制合成
}
func New() *H2O {
  return &H2O{
    semaH: semaphore.NewWeighted(2), //氢原子需要两个
    semaO: semaphore.NewWeighted(1), // 氧原子需要一个
    b:     cyclicbarrier.New(3),  // 需要三个原子才能合成
  }
}


func (h2o *H2O) hydrogen(releaseHydrogen func()) {
  h2o.semaH.Acquire(context.Background(), 1)

  releaseHydrogen() // 输出H
  h2o.b.Await(context.Background()) //等待栅栏放行
  h2o.semaH.Release(1) // 释放氢原子空槽
}


func (h2o *H2O) oxygen(releaseOxygen func()) {
  h2o.semaO.Acquire(context.Background(), 1)

  releaseOxygen() // 输出O
  h2o.b.Await(context.Background()) //等待栅栏放行
  h2o.semaO.Release(1) // 释放氢原子空槽
}

下面是对应的单元测试

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74

package water


import (
    "math/rand"
    "sort"
    "sync"
    "testing"
    "time"
)


func TestWaterFactory(t *testing.T) {
    //用来存放水分子结果的channel
    var ch chan string
    releaseHydrogen := func() {
        ch <- "H"
    }
    releaseOxygen := func() {
        ch <- "O"
    }

    // 300个原子,300个goroutine,每个goroutine并发的产生一个原子
    var N = 100
    ch = make(chan string, N*3)


    h2o := New()

    // 用来等待所有的goroutine完成
    var wg sync.WaitGroup
    wg.Add(N * 3)
   
    // 200个氢原子goroutine
    for i := 0; i < 2*N; i++ {
        go func() {
            time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
            h2o.hydrogen(releaseHydrogen)
            wg.Done()
        }()
    }
    // 100个氧原子goroutine
    for i := 0; i < N; i++ {
        go func() {
            time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond)
            h2o.oxygen(releaseOxygen)
            wg.Done()
        }()
    }
    
    //等待所有的goroutine执行完
    wg.Wait()

    // 结果中肯定是300个原子
    if len(ch) != N*3 {
        t.Fatalf("expect %d atom but got %d", N*3, len(ch))
    }

    // 每三个原子一组,分别进行检查。要求这一组原子中必须包含两个氢原子和一个氧原子,这样才能正确组成一个水分子。
    var s = make([]string, 3)
    for i := 0; i < N; i++ {
        s[0] = <-ch
        s[1] = <-ch
        s[2] = <-ch
        sort.Strings(s)


        water := s[0] + s[1] + s[2]
        if water != "HHO" {
            t.Fatalf("expect a water molecule but got %s", water)
        }
    }
}

如果你没有学习 CyclicBarrier,你可能只会想到,用 WaitGroup 来实现这个水分子制造工厂的例子。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42

type H2O struct {
    semaH *semaphore.Weighted
    semaO *semaphore.Weighted
    wg    sync.WaitGroup //将循环栅栏替换成WaitGroup
}

func New() *H2O {
    var wg sync.WaitGroup
    wg.Add(3)

    return &H2O{
        semaH: semaphore.NewWeighted(2),
        semaO: semaphore.NewWeighted(1),
        wg:    wg,
    }
}


func (h2o *H2O) hydrogen(releaseHydrogen func()) {
    h2o.semaH.Acquire(context.Background(), 1)
    releaseHydrogen()

    // 标记自己已达到,等待其它goroutine到达
    h2o.wg.Done()
    h2o.wg.Wait()

    h2o.semaH.Release(1)
}

func (h2o *H2O) oxygen(releaseOxygen func()) {
    h2o.semaO.Acquire(context.Background(), 1)
    releaseOxygen()

    // 标记自己已达到,等待其它goroutine到达
    h2o.wg.Done()
    h2o.wg.Wait()
    //都到达后重置wg 
    h2o.wg.Add(3)

    h2o.semaO.Release(1)
}

使用 WaitGroup 非常复杂,而且,重用和 Done 方法的调用有并发的问题,程序可能 panic,远远没有使用循环栅栏更加简单直接。

1.3 CyclicBarrier 的实现

CyclicBarrier 的数据结构

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// round
type round struct {
	count    int           // count of goroutines for this roundtrip
	waitCh   chan struct{} // wait channel for this roundtrip
	brokeCh  chan struct{} // channel for isBroken broadcast
	isBroken bool          // is barrier broken
}

// cyclicBarrier impl CyclicBarrier intf
type cyclicBarrier struct {
	parties       int
	barrierAction func() error

	lock  sync.RWMutex
	round *round
}

// New initializes a new instance of the CyclicBarrier, specifying the number of parties.
func New(parties int) CyclicBarrier {
	if parties <= 0 {
		panic("parties must be positive number")
	}
	return &cyclicBarrier{
		parties: parties,
		lock:    sync.RWMutex{},
		round: &round{
			waitCh:  make(chan struct{}),
			brokeCh: make(chan struct{}),
		},
	}
}

Await

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
func (b *cyclicBarrier) Await(ctx context.Context) error {
	var (
		ctxDoneCh <-chan struct{}
	)
	if ctx != nil {
		ctxDoneCh = ctx.Done()
	}

	// check if context is done
	select {
	case <-ctxDoneCh:
		return ctx.Err()
	default:
	}

	b.lock.Lock()

	// check if broken
	if b.round.isBroken {
		b.lock.Unlock()
		return ErrBrokenBarrier
	}

	// increment count of waiters
	b.round.count++

	// saving in local variables to prevent race
	waitCh := b.round.waitCh
	brokeCh := b.round.brokeCh
	count := b.round.count

	b.lock.Unlock()

	if count > b.parties {
		panic("CyclicBarrier.Await is called more than count of parties")
	}

	if count < b.parties {
		// wait other parties
		select {
		case <-waitCh:
			return nil
		case <-brokeCh:
			return ErrBrokenBarrier
		case <-ctxDoneCh:
			b.breakBarrier(true)
			return ctx.Err()
		}
	} else {
		// we are last, run the barrier action and reset the barrier
		if b.barrierAction != nil {
			err := b.barrierAction()
			if err != nil {
				b.breakBarrier(true)
				return err
			}
		}
		b.reset(true)
		return nil
	}
}

func (b *cyclicBarrier) reset(safe bool) {
	b.lock.Lock()
	defer b.lock.Unlock()

	if safe {
		// broadcast to pass waiting goroutines
		close(b.round.waitCh)

	} else if b.round.count > 0 {
		b.breakBarrier(false)
	}

	// create new round
	b.round = &round{
		waitCh:  make(chan struct{}),
		brokeCh: make(chan struct{}),
	}
}

func (b *cyclicBarrier) breakBarrier(needLock bool) {
	if needLock {
		b.lock.Lock()
		defer b.lock.Unlock()
	}

	if !b.round.isBroken {
		b.round.isBroken = true

		// broadcast
		close(b.round.brokeCh)
	}
}

参考

本文内容摘录自:

  1. 极客专栏-鸟叔的 Go 并发编程实战