Golang package sync 剖析(二): sync.WaitGroup

1、前言

Go語言在設計上對同步(Synchronization,數據同步和線程同步)提供大量的支持,好比 goroutine和channel同步原語,庫層面有

  
- sync:提供基本的同步原語(好比Mutex、RWMutex、Locker)和 工具類(Once、WaitGroup、Cond、Pool、Map)
- sync/atomic:提供變量的原子操做(基於硬件指令 compare-and-swap)

-- 引用自《Golang package sync 剖析(一): sync.Once》golang

上一期中,咱們介紹了 sync.Once 如何保障 exactly once 語義,本期文章咱們介紹 package sync 下的另外一個工具類:sync.WaitGroup數據庫

2、爲何須要 WaitGroup

想象一個場景:咱們有一個用戶畫像服務,當一個請求到來時,須要c#

  1. 從 request 裏解析出 user_id 和 畫像維度參數
  2. 根據 user_id 從 ABCDE 五個子服務(數據庫服務、存儲服務、rpc服務等)拉取不一樣維度的信息
  3. 將讀取的信息進行整合,返回給調用方

假設 ABCDE 五個服務的響應時間 p99 是 20~50ms 之間。若是咱們順序調用 ABCDE 讀取信息,不考慮數據整合消耗時間,服務端總體響應時間 p99 是:segmentfault

sum(A, B, C, D, E) => [100ms, 250ms]

先不說業務上能不能接受,響應時間上顯然有很大的優化空間。最直觀的優化方向就是,取數邏輯的總時間消耗:微信

sum(A, B, C, D, E) -> max(A, B, C, D, E)

具體到 coding 上,咱們須要並行調用 ABCDE 五個子服務,待調用所有返回之後,進行數據整合。如何保障所有返回呢?函數

此時,sync.WaitGroup 閃耀登場。工具

3、WaitGroup 用法

官方文檔對 WaitGroup 的描述是:一個 WaitGroup 對象能夠等待一組協程結束。使用方法是:oop

  1. main協程經過調用 wg.Add(delta int) 設置worker協程的個數,而後建立worker協程;
  2. worker協程執行結束之後,都要調用 wg.Done()
  3. main協程調用 wg.Wait() 且被block,直到全部worker協程所有執行結束後返回。

這裏先看一個典型的例子:優化

// src/cmd/compile/internal/ssa/gen/main.go
func  main() {
  // 省略部分代碼 ...
  var wg sync.WaitGroup
  for _, task := range tasks {
    task  := task
    wg.Add(1)
    go func() {
      task()
      wg.Done()
    }()
  }
  wg.Wait()
  // 省略部分代碼...
}

這個例子具有了 WaitGroup 正確使用的大部分要素,包括:ui

  1. wg.Done 必須在全部 wg.Add 以後執行,因此要保證兩個函數都在main協程中調用;
  2. wg.Done 在 worker協程裏調用,尤爲要保證調用一次,不能由於 panic 或任何緣由致使沒有執行(建議使用 defer wg.Done());
  3. wg.Donewg.Wait 在時序上是沒有前後。

細心的朋友可能會發現一行很是詭異的代碼:

task  := task

Go 對 array/slice 進行遍歷時,runtime 會把 task[i] 拷貝到 task 的內存地址,下標 i 會變,而 task 的內存地址不會變。若是不進行此次賦值操做,全部 goroutine 可能讀到的都是最後一個task。爲了讓你們有一個直觀的感受,咱們用下面這段代碼作實驗:

package main

import (
  "fmt"
  "unsafe"
)

func main() {
  tasks := []func(){
    func() { fmt.Printf("1. ") },
    func() { fmt.Printf("2. ") },
  }

  for idx, task := range tasks {
    task()
    fmt.Printf("遍歷 = %v, ", unsafe.Pointer(&task))
    fmt.Printf("下標 = %v, ", unsafe.Pointer(&tasks[idx]))
    task  := task
    fmt.Printf("局部變量 = %v\\n", unsafe.Pointer(&task))
  }
}

這段代碼的打印結果是:

1. 遍歷 = 0x40c140, 下標 = 0x40c138, 局部變量 = 0x40c150
2. 遍歷 = 0x40c140, 下標 = 0x40c13c, 局部變量 = 0x40c158

不一樣機器上執行打印結果有所不一樣,但共同點是:

  1. 遍歷時,數據的內存地址不變
  2. 經過下標取數時,內存地址不一樣
  3. for-loop 內建立的局部變量,即使名字相同,內存地址也不會複用

使用 WaitGroup 時,除了上面提到的注意事項,還須要解決數據回收和異常處理的問題。這裏咱們也提供兩種方式供參考:

  1. 對於 rpc 調用,能夠經過 data channel 和 error channel 蒐集信息,或者二合一的channel
  2. 共享變量,好比加鎖的 map

4、WaitGroup 實現

在討論這個主題以前,建議讀者先思考一下:若是讓你去實現 WaitGroup,你會怎麼作?

鎖?確定不行!

信號量?怎麼實現?

------------切入正題------------

在 Go 源碼裏,WaitGroup 在邏輯上包含:

  1. worker 計數器:main協程調用 wg.Add(delta int) 時增長 delta,調用 wg.Done時減一。
  2. waiter 計數器:調用 wg.Wait 時,計數器加一; worker計數器下降到0時,重置waiter計數器
  3. 信號量:用於阻塞 main協程。調用 wg.Wait 時,經過 runtime_Semacquire 獲取信號量;下降 waiter 計數器時,經過 runtime_Semrelease 釋放信號量。

爲了便於演示,咱們魔改一下上面的例子:

package main

import (
  "fmt"
  "sync"
  "time"
)

func main() {
  tasks  := []func(){
    func() { time.Sleep(time.Second); fmt.Println("1 sec later") },
    func() { time.Sleep(time.Second *  2); fmt.Println("2 sec later") },
}

  var wg sync.WaitGroup // 1-1
  wg.Add(len(tasks))    // 1-2
  for _, task := range tasks {
    task  := task
    go func() {       // 1-3-1
      defer wg.Done() // 1-3-2
      task()          // 1-3-3
    }()               // 1-3-1
  }
  wg.Wait()           // 1-4
  fmt.Println("exit")
}

上面這段代碼中,

  1. 1-1 建立一個 WaitGroup 對象,worker計數器和waiter計數器默認值均爲0。
  2. 1-2 設置 worker計數器爲 len(tasks)
  3. 1-3-1 建立 worker協程,並啓動任務。
  4. 1-4 設置 waiter計數器,獲取信號量,main協程被阻塞。
  5. 1-3-3 中執行結束後,1-3-2 下降worker計數器。當worker計數器下降到0時,

    • 重置 waiter計數器
    • 釋放信號量,main 協程被激活,1-4 wg.Wait 返回

儘管 Add(delta int) 裏 delta 能夠是正數、0、負數。咱們在使用時,delta 老是正數。

wg.Done 等價於 wg.Add(-1)。在本文中,咱們提到 wg.Add時,默認 delta > 0

瞭解了 WaitGroup 的原理之後,咱們看下它的源碼。爲了便於理解,我只保留了核心邏輯。對於這部分邏輯,咱們分三部分講解:

  1. WaitGroup 結構
  2. AddDone
  3. Wait

提示:若是隻想了解 WaitGroup 的正確用法,本文讀到這兒就足夠了。對底層有興趣的朋友能夠繼續讀,不過最好打開IDE,參考源碼一塊兒讀。

4.1 WaitGroup 結構

type WaitGroup struct {
  noCopy noCopy
  state1 [3]uint32
}

WaitGroup 結構體裏有 noCopystate1 兩個字段。

編譯代碼時,go vet 工具會檢查 noCopy 字段,避免 WaitGroup 對象被拷貝。

state1 字段比較秀,在邏輯上它包含了 worker計數器、waiter計數器和信號量。具體如何讀這三個變量,參考下面代碼:

// state returns pointers to the state and sema fields stored within wg.state1.
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
  if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
    return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
  } else {
    return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
  }
}

// 讀取計數器和信號量
statep, semap := wg.state()
state  := atomic.LoadUint64(statep)
v := int32(state >> 32)
w := uint32(state)

三個變量的取數邏輯是:

  • worker計數器:vstatep *uint64左32位
  • waiter計數器:wstatep *uint64右32位
  • 信號量:semapstate1 [3]uint32 的第一個字節/最後一個字節

因此,更新worker計數器,須要這樣作:

state := atomic.AddUint64(statep, uint64(delta)<<32)

更新waiter計數器,須要這樣作:

statep, semap := wg.state()
for {
  state := atomic.LoadUint64(statep)
  if atomic.CompareAndSwapUint64(statep, state, state+1)   {
    // 忽略其餘邏輯
    return
  }
}

細心的朋友可能會發現,worker計數器的更新是直接累加,而 waiter計數器的更新是 CompareAndSwap。這是由於在 main協程中執行 wg.Add 時,只有main協程對 state1 作修改;而 wg.Wait 中修改waiter計數器時,可能有不少個協程在更新 state1。若是你還不太理解這段話,不妨先往下走,瞭解 wg.Addwg.Wait 的細節以後再回頭看。

4.2 Add 和 Done

wg.Add 操做的核心邏輯比較簡單,即修改 worker計數器,根據worker計數器的狀態進行後續操做。簡化版的代碼以下:

func (wg *WaitGroup) Add(delta int) {
  statep, semap := wg.state()
  // 1. 修改worker計數器
  state := atomic.AddUint64(statep, uint64(delta)<<32)
  v := int32(state >> 32)
  w := uint32(state)
  if v <  0 {
    panic("sync: negative WaitGroup counter")
  }
  if w != 0 && delta > 0 && v == int32(delta) {
    panic("sync: WaitGroup misuse: Add called concurrently with Wait")
  }
  // 2. 判斷計數器
  if v > 0 || w == 0 {
    return
  }
  
  // 3. 當 worker計數器下降到0時
  // 重置 waiter計數器,並釋放信號量
  *statep = 0
  for ; w != 0; w-- {
    runtime_Semrelease(semap, false)
  }
}

func (wg *WaitGroup) Done() {
  wg.Add(-1)
}

4.3 Wait

wg.Wait 的邏輯是修改waiter計數器,並等待信號量被釋放。簡化版的代碼以下:

func (wg *WaitGroup) Wait() {
  statep, semap  := wg.state()
  for {
    // 1. 讀取計數器
    state := atomic.LoadUint64(statep)
    v := int32(state >> 32)
    w := uint32(state)
    if v == 0 {
      return
    }

    // 2. 增長waiter計數器
    if atomic.CompareAndSwapUint64(statep, state, state+1) {
      // 3. 獲取信號量
      runtime_Semacquire(semap)
      if *statep != 0 {
        panic("sync: WaitGroup is reused before previous Wait has returned")
      }
    
      // 4. 信號量獲取成功
      return
    }
  }
}

因爲源碼比較長,包含了不少校驗邏輯和註釋,本文中在引用時,在保留核心邏輯的同時均作了不一樣程度的刪減。最後,推薦各位把源碼下載下來,細細研讀一番,從細節上對 WaitGroup 的設計有更深刻的理解。

References

  1. Golang: sync.WaitGroup

掃碼關注微信公衆號「深刻Go語言」

圖片描述

相關文章
相關標籤/搜索