Go 的一個 CAS 操做使用場景

大概一年前,曾經遇到這麼一個問題:
程序中有 N個併發執行的routine,都會向一個size 爲 n 的 channel 裏面寫入數據,這 N 個 routine 有比較高的併發度,同時負載也比較大,因此不但願在寫入數據的時候卡住,所以使用了這樣的代碼。bash

if len(c) < n {
    c <- something // 寫入
}

本來意義是,保證必定可以寫入,防止 worker routine 卡住。但實際運行過程當中發生了routine 卡在 channel 寫入處的現象,緣由很是簡單:併發

多個 routine 同時判斷 len(c)沒有滿,而且同時進入了寫入 channel 的代碼,當 channel 滿了,處理若是不及時,那麼後寫入的routine 便會阻塞在此。

使用一個sync.Mutex把檢查長度和寫入 channel 的代碼保護起來固然能夠解決,可是因爲認爲這個 Mutex 可能影響性能,實際上使用的是一個比較low的辦法解決。app

const (
    _CHAN_SIZE  = 10
    _GUARD_SIZE = 10
)

var c chan interface{} = make(_CHAN_SIZE + _GUARD_SIZE) // 額外分配了一塊保護的空間。

func write(val interface{}) {
    if len(c) < _CHAN_SIZE {
        c <- val
    }
}

在併發執行的多個 routine R1,R2...Rn 的中,同一時間只容許惟一一個 routine 執行某一個操做,而且其餘 routine 須要非阻塞的知道本身無權操做並返回的時候,可使用 CAS 操做。函數

對於這些 worker routine來講,狀況大概是這樣的:oop

『弱弱的瞥一眼那個位置(操做),沒人佔着咱就佔,其餘人佔着咱也不等,直接走人』

比較優雅的方式,是使用go標準庫裏面的 atomic.CompareAndSwap 這一族函數。性能

// CompareAndSwapInt64 executes the compare-and-swap operation for an int64 value.
func CompareAndSwapInt64(addr *int64, old, new int64) (swapped bool)
...

這些函數功能很簡單,當給定的地址的值和 old 相等的時候,設置爲新值,同時返回true,不然返回false
該函數爲原子操做。測試

維基百科上的描述:
比較並交換(compare and swap, CAS)ui

因而上面的代碼能夠這麼寫:atom

func writeMsgWithCASCheck(val interface{}) {
    if atomic.CompareAndSwapInt64(&flag, 0, 1) {
        if len(c) < _CHAN_SIZE {
            c <- val
            atomic.StoreInt64(&obj.flag, 0)
            return nil
        }
        atomic.StoreInt64(&obj.flag, 0)
    }
}

若是要保證必定寫入進去的話,能夠在 atomic外面再套一個 for:code

func writeMsgWithCASCheck(val interface{}) {
    for {
        if atomic.CompareAndSwapInt64(&flag, 0, 1) {
            if len(c) < _CHAN_SIZE {
                ...
        }
    }
}

但這樣的效果就和直接卡在 c <- val同樣,還 佔滿了 cpu(處於忙等狀態)。

針對這種狀況我寫了一個簡單的測試程序:

$ go run cas.go
R(0)+1 R(0)+1 R(0)+1 R(0)+1 R(0)+1 R(0)+1 R(0)+1 R(2)+1 R(3)+1 R(1)+1 R(0)+1 R(1)+1 R(2)+1 R(3)+1 Chan overflow, len: 13.
quit.
$ go run cas.go cas
R(0)+1 R(0)+1 R(0)+1 R(0)+1 R(0)+1 R(0)+1 R(0)+1 R(3)+1 R(1)+1 R(2)+1 R(1)+1 R(0)+1 R(3)+1 R(2)+1 R(1)+1 R(3)+1 R(3)+1 R(3)+1 R(3)+1 R(1)+1 R(2)+1 R(2)+1 R(2)+1 R(3)+1 R(1)+1 R(2)+1 R(3)+1 R(1)+1 R(1)+1 R(2)+1 R(1)+1 R(2)+1 <nil>
quit.

開4個 routine 不停寫入的狀況下仍是很容易出現寫入超過預期size 的狀況的。

完整代碼以下cas.go

package main

import (
    "errors"
    "fmt"
    "os"
    "sync/atomic"
    "time"
)

const (
    _CHAN_SIZE  = 10
    _GUARD_SIZE = 10

    _TEST_CNT = 32
)

type Obj struct {
    flag int64
    c    chan interface{}
}

func (obj *Obj) readLoop() error {
    counter := _TEST_CNT
    for {
        time.Sleep(5 * time.Millisecond)
        if len(obj.c) > _CHAN_SIZE {
            return errors.New(fmt.Sprintf("Chan overflow, len: %v.", len(obj.c)))
        } else if len(obj.c) > 0 {
            <-obj.c
            counter--
        }
        if counter <= 0 {
            return nil
        }
    }
}

func (obj *Obj) writeMsg(idx int, v interface{}) (err error) {
    for {
        if len(obj.c) < _CHAN_SIZE {
            obj.c <- v
            fmt.Printf("R(%v)+1 ", idx)
            return nil
        }
    }
}

func (obj *Obj) writeMsgWithCASCheck(idx int, v interface{}) (err error) {
    for {
        if atomic.CompareAndSwapInt64(&obj.flag, 0, 1) {
            if len(obj.c) < _CHAN_SIZE {
                obj.c <- v
                atomic.StoreInt64(&obj.flag, 0)
                fmt.Printf("R(%v)+1 ", idx)
                return nil
            } else {
                atomic.StoreInt64(&obj.flag, 0)
            }
        }
    }

    return nil
}

func main() {
    useCAS := false
    if len(os.Args) > 1 && os.Args[1] == "cas" {
        useCAS = true
    }
    routineCnt := 4
    tryCnt := _TEST_CNT / routineCnt
    var obj = &Obj{c: make(chan interface{}, _CHAN_SIZE+_GUARD_SIZE)}

    for idx := 0; idx < routineCnt; idx++ {
        go func(nameIdx int) {
            for tryIdx := 0; tryIdx < tryCnt; tryIdx++ {
                if useCAS {
                    obj.writeMsgWithCASCheck(nameIdx, nil)
                } else {
                    obj.writeMsg(nameIdx, nil)
                }
            }
        }(idx)
    }

    // fmt.Println(casObj.readLoop())
    fmt.Println(obj.readLoop())
    fmt.Println("quit.")
}
相關文章
相關標籤/搜索