使用 golang 實現相似 pthread_barrier_t 語義的 barrier 對象

看到golang標準庫sync package WaitGroup 類型, 本覺得是golang 版本的 barrier 對象實現,看到文檔給出的使用示例:
golang

 var wg sync.WaitGroup
    var urls = []string{
            "http://www.golang.org/",
            "http://www.google.com/",
            "http://www.somestupidname.com/",
    }
    for _, url := range urls {           
            // Increment the WaitGroup counter.
            wg.Add(1)            
            // Launch a goroutine to fetch the URL.
            go func(url string) {                   
             // Decrement the counter when the goroutine completes.
                    defer wg.Done()                    
                    // Fetch the URL.
                    http.Get(url)
            }(url)
    }    
    // Wait for all HTTP fetches to complete.
    wg.Wait()

能夠看出WaitGroup 類型主要用於某個goroutine(調用Wait() 方法的那個),  等待個數不定goroutine(內部調用Done() 方法),數組

Add 方法對內部計數,添加或減小,Done方法實際上是Add(-1);oop

與pthread_barrier_t 有着語義上的差異,pthread_barrier_wait() 的調用者之間互相等待,就比如5名隊員(線程)參加跨欄比賽,使用 pthread_barrier_init 初始化最後一個參數爲5,  五個隊員都是好基友, 定了規矩, 無論誰先到欄杆, 都要等隊友,直到最後一名隊員跨過欄時,而後同一塊兒步點再次出發。下面時使用pthread_barrier_t 的簡單示例 5個線程,每一個線程擁有一個私有數組,及增量數字:fetch

#define _GNU_SOURCE 

#include <pthread.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#define NTHR 5
#define NARR 6
#define INLOOPS 1000
#define OUTLOOPS 10
#define err_abort(code,text) do { \
    char errbuf[128] = {0};         \
    fprintf (stderr, "%s at \"%s\":%d: %s\n", \
        (text), __FILE__, __LINE__, strerror_r(code,errbuf,128)); \
    abort (); \
} while (0)

typedef struct thrArg {
    pthread_t   tid;
    int         incr;
    int         arr[NARR];
}thrArg;

pthread_barrier_t   barrier;
thrArg  thrs[NTHR];

void *thrFunc (void *arg)
{
    thrArg *self = (thrArg*)arg;    
    int j, i, k, status;
    
    for (i = 0; i < OUTLOOPS; i++) {
        status = pthread_barrier_wait (&barrier);
        if (status > 0)
            err_abort (status, "wait on barrier");
        //每一個線程迭代 INLOOPS 次,對本身的內部數組arr 成員加上 本身的增量值
        for (j = 0; j < INLOOPS; j++)
            for (k = 0; k < NARR; k++)
                self->arr[k] += self->incr;
        //先執行完迭代的線程在此等待,直到最後一個到達
        status = pthread_barrier_wait (&barrier);
        if (status > 0)
            err_abort (status, "wait on barrier");
        //最後一個到達的線程,把全部線程的內部增量加1
        //此時其餘先到的線程阻塞在第一次wait調用處,因此最後一個到達的線程
        //能夠排他性地訪問全部線程的內部狀態,if 語句執行完後,跳到第一次wait處,
        //其餘阻塞在第一次wait處的線程,獲得釋放,你們一塊使用新的增量作計算
        if (status == PTHREAD_BARRIER_SERIAL_THREAD ) {
            int i;
            for (i = 0; i < NTHR; i++)
                thrs[i].incr += 1;
        }
    }
    return NULL;
}

int main (int arg, char *argv[])
{
    int i, j;
    int status;

    pthread_barrier_init (&barrier, NULL, NTHR);

    for (i = 0; i < NTHR; i++) {
        thrs[i].incr = i;
        for (j = 0; j < NARR; j++)
            thrs[i].arr[j] = j + 1;

        status = pthread_create (&thrs[i].tid,
            NULL, thrFunc, (void*)&thrs[i]);
        if (status != 0)
            err_abort (status, "create thread");
    }

    for (i = 0; i < NTHR; i++) {
        status = pthread_join (thrs[i].tid, NULL);
        if (status != 0)
            err_abort (status, "join thread");

        printf ("%02d: (%d) ", i, thrs[i].incr);

        for (j = 0; j < NARR; j++)
            printf ("%010u ", thrs[i].arr[j]);
        printf ("\n");
    }
    pthread_barrier_destroy (&barrier);
    return 0;
}

怎麼用golang 來表達上述c 代碼,須要實現pthread_barrier_t 等價語義的的 barrier 對象,可使用golang 已有的mutex, condgoogle

對象實現 barrier:url

package main
import (
    "fmt"
    "sync"
)
type Barrier struct{
    lock  sync.Mutex
    cond  sync.Cond
    threshold  int    //總的等待個數
    count      int    //還剩多少沒有到達barrier,即沒有完成wait調用個數
    cycle      bool   //用於重初始化下一個wait 週期,
}
func NewBarrier(n  int) *Barrier{
    b := &Barrier{threshold: n, count: n} 
    b.cond.L = &b.lock
    return b
}
//last == true ,說明最有一個到達
func (b *Barrier)Wait()(last bool){
    b.lock.Lock()
    defer  b.lock.Unlock()
    cycle :=  b.cycle
    b.count--
    //最後一個到達負責,重初始化count 計數,cycle 變量翻轉,
    if b.count == 0 {
       b.cycle  =  !b.cycle 
       b.count = b.threshold 
       b.cond.Broadcast()
       last = true
    }else{
      for cycle == b.cycle {
          b.cond.Wait()
      }
    }
    return
}
type thrArg struct{
   incr  int
   arr   [narr]int
}
var (
    thrs  [nthr]thrArg
    wg   sync.WaitGroup
    barrier = NewBarrier(nthr)
)
const (
    outloops = 10
    inloops  = 1000
    nthr  = 5
    narr  = 6
)

func thrFunc(arg  *thrArg){
    defer wg.Done()
    for i := 0; i < outloops; i++{
        barrier.Wait()
        for j := 0; j < inloops; j++{
            for k:= 0; k < narr; k++{
                arg.arr[k] += arg.incr
            }
        }
        if barrier.Wait() {
            for i := 0; i < nthr; i++{
                thrs[i].incr += 1
            }
        }
    }
}

func  main(){
    for i:= 0; i < nthr; i++{
        thrs[i].incr =  i
        for j := 0; j < narr; j++{
            thrs[i].arr[j] = j + 1
        }
        wg.Add(1)
        go thrFunc(&thrs[i])
    }
    wg.Wait()
    //全部goroutine完成,main goroutine,檢查最後的結果
    for i := 0; i < nthr; i++{
        fmt.Printf("%02d: (%d) ", i, thrs[i].incr)
        for j := 0; j < narr; j++{
            fmt.Printf ("%010d ", thrs[i].arr[j]);
        }
        fmt.Println()
    }
}
相關文章
相關標籤/搜索