看到golang標準庫sync package WaitGroup 類型, 本覺得是golang 版本的 barrier 對象實現,看到文檔給出的使用示例:
golang
var wg sync.WaitGroup var urls = []string{ "http://www.golang.org/", "http://www.google.com/", "http://www.somestupidname.com/", } for _, url := range urls { // Increment the WaitGroup counter. wg.Add(1) // Launch a goroutine to fetch the URL. go func(url string) { // Decrement the counter when the goroutine completes. defer wg.Done() // Fetch the URL. http.Get(url) }(url) } // Wait for all HTTP fetches to complete. wg.Wait()
能夠看出WaitGroup 類型主要用於某個goroutine(調用Wait() 方法的那個), 等待個數不定goroutine(內部調用Done() 方法),數組
Add 方法對內部計數,添加或減小,Done方法實際上是Add(-1);oop
與pthread_barrier_t 有着語義上的差異,pthread_barrier_wait() 的調用者之間互相等待,就比如5名隊員(線程)參加跨欄比賽,使用 pthread_barrier_init 初始化最後一個參數爲5, 五個隊員都是好基友, 定了規矩, 無論誰先到欄杆, 都要等隊友,直到最後一名隊員跨過欄時,而後同一塊兒步點再次出發。下面時使用pthread_barrier_t 的簡單示例 5個線程,每一個線程擁有一個私有數組,及增量數字:fetch
#define _GNU_SOURCE #include <pthread.h> #include <stdio.h> #include <string.h> #include <stdlib.h> #define NTHR 5 #define NARR 6 #define INLOOPS 1000 #define OUTLOOPS 10 #define err_abort(code,text) do { \ char errbuf[128] = {0}; \ fprintf (stderr, "%s at \"%s\":%d: %s\n", \ (text), __FILE__, __LINE__, strerror_r(code,errbuf,128)); \ abort (); \ } while (0) typedef struct thrArg { pthread_t tid; int incr; int arr[NARR]; }thrArg; pthread_barrier_t barrier; thrArg thrs[NTHR]; void *thrFunc (void *arg) { thrArg *self = (thrArg*)arg; int j, i, k, status; for (i = 0; i < OUTLOOPS; i++) { status = pthread_barrier_wait (&barrier); if (status > 0) err_abort (status, "wait on barrier"); //每一個線程迭代 INLOOPS 次,對本身的內部數組arr 成員加上 本身的增量值 for (j = 0; j < INLOOPS; j++) for (k = 0; k < NARR; k++) self->arr[k] += self->incr; //先執行完迭代的線程在此等待,直到最後一個到達 status = pthread_barrier_wait (&barrier); if (status > 0) err_abort (status, "wait on barrier"); //最後一個到達的線程,把全部線程的內部增量加1 //此時其餘先到的線程阻塞在第一次wait調用處,因此最後一個到達的線程 //能夠排他性地訪問全部線程的內部狀態,if 語句執行完後,跳到第一次wait處, //其餘阻塞在第一次wait處的線程,獲得釋放,你們一塊使用新的增量作計算 if (status == PTHREAD_BARRIER_SERIAL_THREAD ) { int i; for (i = 0; i < NTHR; i++) thrs[i].incr += 1; } } return NULL; } int main (int arg, char *argv[]) { int i, j; int status; pthread_barrier_init (&barrier, NULL, NTHR); for (i = 0; i < NTHR; i++) { thrs[i].incr = i; for (j = 0; j < NARR; j++) thrs[i].arr[j] = j + 1; status = pthread_create (&thrs[i].tid, NULL, thrFunc, (void*)&thrs[i]); if (status != 0) err_abort (status, "create thread"); } for (i = 0; i < NTHR; i++) { status = pthread_join (thrs[i].tid, NULL); if (status != 0) err_abort (status, "join thread"); printf ("%02d: (%d) ", i, thrs[i].incr); for (j = 0; j < NARR; j++) printf ("%010u ", thrs[i].arr[j]); printf ("\n"); } pthread_barrier_destroy (&barrier); return 0; }
怎麼用golang 來表達上述c 代碼,須要實現pthread_barrier_t 等價語義的的 barrier 對象,可使用golang 已有的mutex, condgoogle
對象實現 barrier:url
package main import ( "fmt" "sync" ) type Barrier struct{ lock sync.Mutex cond sync.Cond threshold int //總的等待個數 count int //還剩多少沒有到達barrier,即沒有完成wait調用個數 cycle bool //用於重初始化下一個wait 週期, } func NewBarrier(n int) *Barrier{ b := &Barrier{threshold: n, count: n} b.cond.L = &b.lock return b } //last == true ,說明最有一個到達 func (b *Barrier)Wait()(last bool){ b.lock.Lock() defer b.lock.Unlock() cycle := b.cycle b.count-- //最後一個到達負責,重初始化count 計數,cycle 變量翻轉, if b.count == 0 { b.cycle = !b.cycle b.count = b.threshold b.cond.Broadcast() last = true }else{ for cycle == b.cycle { b.cond.Wait() } } return } type thrArg struct{ incr int arr [narr]int } var ( thrs [nthr]thrArg wg sync.WaitGroup barrier = NewBarrier(nthr) ) const ( outloops = 10 inloops = 1000 nthr = 5 narr = 6 ) func thrFunc(arg *thrArg){ defer wg.Done() for i := 0; i < outloops; i++{ barrier.Wait() for j := 0; j < inloops; j++{ for k:= 0; k < narr; k++{ arg.arr[k] += arg.incr } } if barrier.Wait() { for i := 0; i < nthr; i++{ thrs[i].incr += 1 } } } } func main(){ for i:= 0; i < nthr; i++{ thrs[i].incr = i for j := 0; j < narr; j++{ thrs[i].arr[j] = j + 1 } wg.Add(1) go thrFunc(&thrs[i]) } wg.Wait() //全部goroutine完成,main goroutine,檢查最後的結果 for i := 0; i < nthr; i++{ fmt.Printf("%02d: (%d) ", i, thrs[i].incr) for j := 0; j < narr; j++{ fmt.Printf ("%010d ", thrs[i].arr[j]); } fmt.Println() } }