C# Barrier 實現

當您需要一組任務並行地運行一連串的階段,但是每個階段都要等待所有其他任務都完成前一階段之後才能開始時,您可以通過Barrier實例來同步這一類協同工作。
Barrier初始化後,將等待特定數量的信號到來,這個數量在Barrier初始化時指定。在所指定個數的信號都已經到來後,Barrier將執行一個指定的動作,這個動作也是在Barrier初始化時指定的。Barrier在執行動作之後將會重置,這時又將等待特定數量的信號到來,再執行指定動作。信號通過成員函數SignalAndWait()來發送,執行SignalAndWait()函數的Task或者線程將會進入等待;Barrier等待特定數量的信號到達,執行完指定動作後被重置,這時SignalAndWait()函數所在的Task或者線程將繼續運行。在程序的運行過程中,可以通過成員函數AddParticipant()和RemoveParticipant()來增加或者減少需要等待的信號數量。讓我們來看看Barrier的實現:

/// <summary>
/// Excerpt of the .NET Framework reference implementation of System.Threading.Barrier.
/// A barrier lets a group of participants work through a series of phases: a participant that calls
/// SignalAndWait blocks until every participant of the current phase has signaled. An optional
/// post-phase action runs once per completed phase, on the thread that completes it (the last
/// signaler, or a RemoveParticipants caller that completes the phase).
/// </summary>
/// <remarks>
/// NOTE(review): this excerpt omits members it references (m_exception, ThrowIfDisposed,
/// GetCurrentTotal, and the IDisposable.Dispose implementation), so it does not compile on its own.
/// SR and BarrierPostPhaseException are framework-internal helpers.
/// </remarks>
public class Barrier : IDisposable
{
    // Packed barrier state, read with a volatile load and updated only via
    // Interlocked.CompareExchange (see SetCurrentTotal):
    //   bits 0-14  (TOTAL_MASK)   : total number of participants, hence at most 32767 participants
    //   bit  15                   : unused
    //   bits 16-30 (CURRENT_MASK) : number of participants that have signaled in the current phase
    //   bit  31    (SENSE_MASK)   : the phase "sense"; SetCurrentTotal sets this bit when sense == false
    // The sense is flipped every time a phase completes, which lets threads tell consecutive phases apart.
    volatile int m_currentTotalCount;
    const int CURRENT_MASK = 0x7FFF0000;
    const int TOTAL_MASK = 0x00007FFF;
     // Bitmask to extract the sense flag
    const int SENSE_MASK = unchecked((int)0x80000000);

    // The maximum participants the barrier can operate = 32767 ( 2 power 15 - 1 )
    const int MAX_PARTICIPANTS = TOTAL_MASK;
    // Number of phases completed so far; read/written through CurrentPhaseNumber (Volatile.Read/Write).
     long m_currentPhase;
    // Waiters of a phase whose sense is true block on m_evenEvent, otherwise on m_oddEvent
    // (see SignalAndWait); SetResetEvents sets one and resets the other when a phase finishes.
    ManualResetEventSlim m_oddEvent;
    ManualResetEventSlim m_evenEvent;
    // Execution context captured at construction time (only when a post-phase action is supplied and
    // context flow is not suppressed); each post-phase action runs inside a copy of it.
    ExecutionContext m_ownerThreadContext;
    // Cached ContextCallback for InvokePostPhaseAction, created lazily in FinishPhase.
    [SecurityCritical]
    private static ContextCallback s_invokePostPhaseAction;
    // User callback run after each completed phase; may be null.
    Action<Barrier> m_postPhaseAction;
    // ManagedThreadId of the thread currently running the post-phase action, 0 otherwise.
    // Used to reject calls into the barrier from inside the post-phase action.
    int m_actionCallerID;
    
    /// <summary>Initializes a barrier for <paramref name="participantCount"/> participants with no post-phase action.</summary>
    public Barrier(int participantCount): this(participantCount, null) {}
    
    /// <summary>Initializes a barrier for <paramref name="participantCount"/> participants.</summary>
    /// <param name="participantCount">Initial number of participants, from 0 to MAX_PARTICIPANTS (32767).</param>
    /// <param name="postPhaseAction">Optional action run after each phase completes; may be null.</param>
    /// <exception cref="ArgumentOutOfRangeException">participantCount is negative or greater than MAX_PARTICIPANTS.</exception>
    public Barrier(int participantCount, Action<Barrier> postPhaseAction)
    {
        if (participantCount < 0 || participantCount > MAX_PARTICIPANTS)
        {
            throw new ArgumentOutOfRangeException("participantCount", participantCount, SR.GetString(SR.Barrier_ctor_ArgumentOutOfRange));
        }
        // Initial packed state: total = participantCount, current = 0, sense bit clear.
        m_currentTotalCount = (int)participantCount;
        m_postPhaseAction = postPhaseAction;

        // The odd event starts set and the even event starts reset.
        // NOTE(review): phase 0 waiters presumably block on m_evenEvent, assuming GetCurrentTotal
        // (not shown) reports sense == true for a clear sense bit — confirm.
        m_oddEvent = new ManualResetEventSlim(true);
        m_evenEvent = new ManualResetEventSlim(false);

        // Capture the context only if there is a post-phase action and context flow is not suppressed
        if (postPhaseAction != null && !ExecutionContext.IsFlowSuppressed())
        {
        m_ownerThreadContext = ExecutionContext.Capture();
        }
        m_actionCallerID = 0;
    }

    /// <summary>Adds one participant to the barrier.</summary>
    /// <returns>The phase number of the barrier in which the new participant will first participate.</returns>
    /// <exception cref="InvalidOperationException">Adding a participant would exceed MAX_PARTICIPANTS
    /// (the ArgumentOutOfRangeException from AddParticipants is translated into this type).</exception>
    public long AddParticipant()
    {
        try
        {
            return AddParticipants(1);
        }
        catch (ArgumentOutOfRangeException)
        {
            throw new InvalidOperationException(SR.GetString(SR.Barrier_AddParticipants_Overflow_ArgumentOutOfRange));
        }
    }

    /// <summary>Adds <paramref name="participantCount"/> participants to the barrier.</summary>
    /// <returns>The phase number of the barrier in which the new participants will first participate.</returns>
    /// <exception cref="ArgumentOutOfRangeException">participantCount is less than 1, or the new total would exceed MAX_PARTICIPANTS.</exception>
    /// <exception cref="InvalidOperationException">Called from within the post-phase action.</exception>
    public long AddParticipants(int participantCount)
    {
        ThrowIfDisposed();
        if (participantCount < 1 )
        {
            throw new ArgumentOutOfRangeException("participantCount", participantCount, SR.GetString(SR.Barrier_AddParticipants_NonPositive_ArgumentOutOfRange));
        }
        else if (participantCount > MAX_PARTICIPANTS) //overflow
        {
            throw new ArgumentOutOfRangeException("participantCount", SR.GetString(SR.Barrier_AddParticipants_Overflow_ArgumentOutOfRange));
        }

        // Calling back into the barrier from the post-phase action is not allowed.
        if (m_actionCallerID != 0 && Thread.CurrentThread.ManagedThreadId == m_actionCallerID)
        {
            throw new InvalidOperationException(SR.GetString(SR.Barrier_InvalidOperation_CalledFromPHA));
        }

        // CAS loop: re-read the packed state and retry until the update is published without interference.
        SpinWait spinner = new SpinWait();
        long newPhase = 0;
        while (true)
        {
            int currentTotal = m_currentTotalCount;
            int total;
            int current;
            bool sense;
            GetCurrentTotal(currentTotal, out current, out total, out sense);
            if (participantCount + total > MAX_PARTICIPANTS) //overflow
            {
                throw new ArgumentOutOfRangeException("participantCount",SR.GetString(SR.Barrier_AddParticipants_Overflow_ArgumentOutOfRange));
            }
            if (SetCurrentTotal(currentTotal, current, total + participantCount, sense))
            {
                // If the sense no longer matches the parity of the phase number, the current phase has
                // already finished but CurrentPhaseNumber has not been incremented yet; in that case the
                // new participants join the next phase.
                long currPhase = CurrentPhaseNumber;
                newPhase = (sense != (currPhase % 2 == 0)) ? currPhase + 1 : currPhase;
                if (newPhase != currPhase)
                {
                    // Wait until the finishing thread has set the event of the just-completed phase
                    // (SetResetEvents), so the new participants do not start the next phase early.
                    if (sense)
                    {
                        m_oddEvent.Wait();
                    }
                    else { m_evenEvent.Wait(); }
                }
                else
                {
                    // The phase is still in progress: make sure the event this phase's waiters block on
                    // is reset (it may still be set from an earlier phase).
                    if (sense && m_evenEvent.IsSet)
                        m_evenEvent.Reset();
                    else if (!sense && m_oddEvent.IsSet)
                        m_oddEvent.Reset();
                }
                break;
            }
            spinner.SpinOnce();
        }
        return newPhase;
    }
    
    /// <summary>Removes one participant from the barrier.</summary>
    public void RemoveParticipant()
    {
        RemoveParticipants(1);
    }

    /// <summary>
    /// Removes <paramref name="participantCount"/> participants. If the remaining participants (when
    /// more than zero) have all already signaled in the current phase, this call finishes the phase.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">participantCount is less than 1 or greater than the current total.</exception>
    /// <exception cref="InvalidOperationException">Fewer participants would remain than have already signaled in
    /// the current phase, or the method was called from within the post-phase action.</exception>
    public void RemoveParticipants(int participantCount)
    {
        ThrowIfDisposed();
        if (participantCount < 1)
        {
            throw new ArgumentOutOfRangeException("participantCount", participantCount,SR.GetString(SR.Barrier_RemoveParticipants_NonPositive_ArgumentOutOfRange));
        }
        if (m_actionCallerID != 0 && Thread.CurrentThread.ManagedThreadId == m_actionCallerID)
        {
            throw new InvalidOperationException(SR.GetString(SR.Barrier_InvalidOperation_CalledFromPHA));
        }

        // CAS loop over the packed state.
        SpinWait spinner = new SpinWait();
        while (true)
        {
            int currentTotal = m_currentTotalCount;
            int total;
            int current;
            bool sense;
            GetCurrentTotal(currentTotal, out current, out total, out sense);

            if (total < participantCount)
            {
                throw new ArgumentOutOfRangeException("participantCount",SR.GetString(SR.Barrier_RemoveParticipants_ArgumentOutOfRange));
            }
            if (total - participantCount < current)
            {
                throw new InvalidOperationException(SR.GetString(SR.Barrier_RemoveParticipants_InvalidOperation));
            }
            // If the remaining participants = current participants, then finish the current phase:
            // reset current to 0 and flip the sense in the same CAS, then run FinishPhase.
            int remaingParticipants = total - participantCount;
            if (remaingParticipants > 0 && current == remaingParticipants )
            {
                if (SetCurrentTotal(currentTotal, 0, total - participantCount, !sense))
                {
                    FinishPhase(sense);
                    break;
                }
            }
            else
            {
                // Only the total changes; the phase continues.
                if (SetCurrentTotal(currentTotal, current, total - participantCount, sense))
                {
                    break;
                }
            }
            spinner.SpinOnce();
        }
    }

    /// <summary>Signals arrival at the barrier and waits, without timeout or cancellation, for all participants.</summary>
    public void SignalAndWait()
    {
        SignalAndWait(new CancellationToken());
    }

    /// <summary>Signals arrival at the barrier and waits, without timeout, for all participants while observing <paramref name="cancellationToken"/>.</summary>
    public void SignalAndWait(CancellationToken cancellationToken)
    {
        SignalAndWait(Timeout.Infinite, cancellationToken);
    }

    /// <summary>
    /// Signals that this participant has reached the barrier and waits for all other participants of the
    /// current phase, for at most <paramref name="millisecondsTimeout"/> ms (Timeout.Infinite waits forever).
    /// </summary>
    /// <returns>true if the phase completed; false if the timeout expired first, in which case this
    /// participant's signal is withdrawn again.</returns>
    /// <exception cref="ArgumentOutOfRangeException">millisecondsTimeout is less than -1.</exception>
    /// <exception cref="InvalidOperationException">The barrier has zero participants, more threads signaled than
    /// there are participants, or the method was called from within the post-phase action.</exception>
    /// <exception cref="OperationCanceledException">cancellationToken was canceled before the phase completed.</exception>
    /// <exception cref="BarrierPostPhaseException">The post-phase action threw an exception.</exception>
    public bool SignalAndWait(int millisecondsTimeout, CancellationToken cancellationToken)
    {
        ThrowIfDisposed();
        cancellationToken.ThrowIfCancellationRequested();
        if (millisecondsTimeout < -1)
        {
            throw new System.ArgumentOutOfRangeException("millisecondsTimeout", millisecondsTimeout,SR.GetString(SR.Barrier_SignalAndWait_ArgumentOutOfRange));
        }
        if (m_actionCallerID != 0 && Thread.CurrentThread.ManagedThreadId == m_actionCallerID)
        {
            throw new InvalidOperationException(SR.GetString(SR.Barrier_InvalidOperation_CalledFromPHA));
        }
        bool sense; // The sense of the barrier *before* the phase associated with this SignalAndWait call completes
        int total;
        int current;
        int currentTotal;
        long phase;
        // CAS loop that registers this thread's arrival (current + 1).
        SpinWait spinner = new SpinWait();
        while (true)
        {
            currentTotal = m_currentTotalCount;
            GetCurrentTotal(currentTotal, out current, out total, out sense);
            phase = CurrentPhaseNumber;
            // throw if zero participants
            if (total == 0)
            {
                throw new InvalidOperationException(SR.GetString(SR.Barrier_SignalAndWait_InvalidOperation_ZeroTotal));
            }
            // Try to detect if the number of threads for this phase exceeded the total number of participants or not
            // This can be detected if the current is zero which means all participants for that phase has arrived and the phase number is not changed yet
            if (current == 0 && sense != (CurrentPhaseNumber % 2 == 0))
            {
                throw new InvalidOperationException(SR.GetString(SR.Barrier_SignalAndWait_InvalidOperation_ThreadsExceeded));
            }
            //This is the last thread, finish the phase: reset current to 0 and flip the sense atomically
            if (current + 1 == total)
            {
                if (SetCurrentTotal(currentTotal, 0, total, !sense))
                {
                    FinishPhase(sense);
                    return true;
                }
            }
            else if (SetCurrentTotal(currentTotal, current + 1, total, sense))
            {
                break; }
            spinner.SpinOnce();
        }
        
        // ** Perform the real wait **
        // select the correct event to wait on, based on the current sense.
        ManualResetEventSlim eventToWaitOn = (sense) ? m_evenEvent : m_oddEvent;

        bool waitWasCanceled = false;
        bool waitResult = false;
        try
        {
        waitResult = DiscontinuousWait(eventToWaitOn, millisecondsTimeout, cancellationToken, phase);
        }
        catch (OperationCanceledException )
        {
            waitWasCanceled = true;
        }
        // In case of a race where another thread returned from SignalAndWait and disposed the barrier
        // while the current thread was still calling Wait on the (now disposed) event.
        catch (ObjectDisposedException)
        {
            // make sure the current phase for this thread is already finished, otherwise propagate the exception
            if (phase < CurrentPhaseNumber) 
                waitResult = true;
            else
                throw;
        }
        if (!waitResult)
        {
            //reset the SpinWait to prepare it for the next loop
            spinner.Reset();

            //If the wait timed out or was canceled and the other threads have not all reached the barrier yet, decrement the current count again
            while (true)
            {
                bool newSense;
                currentTotal = m_currentTotalCount;
                GetCurrentTotal(currentTotal, out current, out total, out newSense);
                // If the timeout expired and the phase has just finished, return true and this is considered as succeeded SignalAndWait
                //otherwise the timeout expired and the current phase has not been finished yet, return false
                //The phase is finished if the phase member variable is changed (incremented) or the sense has been changed
                // we have to use the statements in the comparison below for two cases:
                // 1- The sense is changed but the last thread didn't update the phase yet
                // 2- The phase is already incremented but the sense flipped twice due to the termination of the next phase
                if (phase < CurrentPhaseNumber || sense != newSense)
                {

                    // The current phase has been finished, but we shouldn't return before the events are set/reset otherwise this thread could start
                    // next phase and the appropriate event has not reset yet which could make it return immediately from the next phase SignalAndWait
                    // before waiting other threads
 WaitCurrentPhase(eventToWaitOn, phase);
                    Debug.Assert(phase < CurrentPhaseNumber);
                    break;
                }
                //The phase has not been finished yet, try to update the current count.
                if (SetCurrentTotal(currentTotal, current - 1, total, sense))
                {
                    //if here, then the attempt to backout was successful.
                    //throw (a fresh) oce if cancellation woke the wait
                    //or return false if it was the timeout that woke the wait.
                    //
                    if (waitWasCanceled)
                        throw new OperationCanceledException(SR.GetString(SR.Common_OperationCanceled), cancellationToken);
                    else
                        return false;
                }
                spinner.SpinOnce();
            }
        }

        // Surface a failure of the post-phase action to every waiting participant, not only to the finishing thread.
        if (m_exception != null)
            throw new BarrierPostPhaseException(m_exception);

        return true;

    }
    
    // Completes the current phase. Runs the post-phase action, if any: inside the captured execution
    // context when one was captured, otherwise directly. Records any exception the action throws in
    // m_exception, or clears it on success. Then calls SetResetEvents to advance the phase and release
    // the waiters. observedSense is the sense of the phase being finished. If the action threw, a
    // BarrierPostPhaseException is thrown from the finally block after the events have been updated.
    private void FinishPhase(bool observedSense)
    {
        // Execute the PHA in try/finally block to reset the variables back in case of it threw an exception
        if (m_postPhaseAction != null)
        {
            try
            {
                // Mark this thread so re-entrant calls from the action are rejected.
                m_actionCallerID = Thread.CurrentThread.ManagedThreadId;
                if (m_ownerThreadContext != null)
                {
                    var currentContext = m_ownerThreadContext;
                    m_ownerThreadContext = m_ownerThreadContext.CreateCopy(); // create a copy for the next run

                    // Lazily create and cache the callback delegate.
                    ContextCallback handler = s_invokePostPhaseAction;
                    if (handler == null)
                    {
                        s_invokePostPhaseAction = handler = InvokePostPhaseAction;
                    }
                    // Run the action in the captured context, then dispose the context that was just used.
                    ExecutionContext.Run(currentContext, handler, this); currentContext.Dispose();
                }
                else
                {
                    m_postPhaseAction(this);
                }
                m_exception = null; // reset the exception if it was set previously
            }
            catch (Exception ex)
            {
                m_exception = ex;
            }
            finally
            {
                // Always advance the phase and release the waiters, even if the action failed.
                m_actionCallerID = 0;
               SetResetEvents(observedSense); if(m_exception != null)
                    throw new BarrierPostPhaseException(m_exception);
            }
        }
        else
        {
         SetResetEvents(observedSense);
        }
    }

    // Increments CurrentPhaseNumber, then releases the waiters of the finished phase.
    // When observedSense is true, m_evenEvent is set and m_oddEvent is reset; otherwise the reverse.
    private void SetResetEvents(bool observedSense)
    {
        // Increment the phase count using Volatile class because m_currentPhase is 64 bit long type, that could cause torn write on 32 bit machines
        // NOTE(review): this is a read-then-write, not an atomic increment. It is presumably safe because
        // only the thread whose CAS completed the phase reaches this point — confirm.
        CurrentPhaseNumber = CurrentPhaseNumber + 1;
        if (observedSense)
        {
            m_oddEvent.Reset();
            m_evenEvent.Set();
        }
        else
        {
            m_evenEvent.Reset();
            m_oddEvent.Set();
        }
    }

    /// <summary>
    /// Waits on <paramref name="currentPhaseEvent"/> in slices that start at 100 ms and double up to a
    /// 10 s ceiling, re-checking CurrentPhaseNumber between slices. This detects a completed phase even if
    /// its event was set and then reset again by the next phase before this thread observed it.
    /// </summary>
    /// <param name="currentPhaseEvent">The event the waiters of the observed phase block on.</param>
    /// <param name="totalTimeout">Total timeout in milliseconds, or Timeout.Infinite.</param>
    /// <param name="token">Cancellation token passed to each ManualResetEventSlim.Wait call.</param>
    /// <param name="observedPhase">The phase number read when this thread signaled.</param>
    /// <returns>True if the event is set or the phase number changed, false if the timeout expired.</returns>
    private bool DiscontinuousWait(ManualResetEventSlim currentPhaseEvent, int totalTimeout, CancellationToken token, long observedPhase)
    {
        int maxWait = 100; // 100 ms
        int waitTimeCeiling = 10000; // 10 seconds
        while (observedPhase == CurrentPhaseNumber)
        {
            // the next wait time, the min of the maxWait and the totalTimeout
            int waitTime = totalTimeout == Timeout.Infinite ? maxWait : Math.Min(maxWait, totalTimeout);
            if (currentPhaseEvent.Wait(waitTime, token)) return true;

            //update the total wait time
            if (totalTimeout != Timeout.Infinite)
            {
                totalTimeout -= waitTime;
                if (totalTimeout <= 0) return false;
            }

            //if the maxwait exceeded 10 seconds then we will stop increasing the maxWait time and keep it 10 seconds, otherwise keep doubling it
            maxWait = maxWait >= waitTimeCeiling ? waitTimeCeiling : Math.Min(maxWait << 1, waitTimeCeiling);
        }

        //if we exited the loop because the observed phase doesn't match the current phase, then we have to spin to make sure
        //the event is set or the next phase is finished
        WaitCurrentPhase(currentPhaseEvent, observedPhase);
        return true;
    }

    // Spins until the finishing thread has released the waiters of observedPhase; see the conditions below.
    private void WaitCurrentPhase(ManualResetEventSlim currentPhaseEvent, long observedPhase)
    {
        //spin until either of these two conditions succeeds
        //1- The event is set
        //2- the phase count is incremented more than one time, this means the next phase is finished as well,
        //but the event will be reset again, so we check the phase count instead
        SpinWait spinner = new SpinWait();
        while (!currentPhaseEvent.IsSet && CurrentPhaseNumber - observedPhase <= 1)
        {
            spinner.SpinOnce();
        }
    }

    // ContextCallback adapter used with ExecutionContext.Run: obj is the Barrier whose post-phase action runs.
    private static void InvokePostPhaseAction(object obj)
    {
        var thisBarrier = (Barrier)obj;
       thisBarrier.m_postPhaseAction(thisBarrier);
    }

    // Packs (current, total, sense) into the layout described at m_currentTotalCount. It then tries to
    // publish the result with a single CAS against currentTotal, the previously read state. A false sense
    // is encoded by setting SENSE_MASK. Returns false if another thread changed the state first; the
    // caller then re-reads and retries. current and total are not masked here, so callers must keep
    // them within 15 bits.
    private bool SetCurrentTotal(int currentTotal, int current, int total, bool sense)
    {
        int newCurrentTotal = (current <<16) | total;           
        if (!sense)
        {
            newCurrentTotal |= SENSE_MASK;
        }
        return Interlocked.CompareExchange(ref m_currentTotalCount, newCurrentTotal, currentTotal) == currentTotal;
    }

    /// <summary>Gets the total number of participants in the barrier (the low 15 bits of the packed state).</summary>
    public int ParticipantCount
    {
        get { return (int)(m_currentTotalCount & TOTAL_MASK); }
    }

    /// <summary>Gets the number of the barrier's current phase, i.e. how many phases have completed so far.</summary>
    public long CurrentPhaseNumber
    {
        // use the new Volatile.Read/Write method because it is cheaper than Interlocked.Read on AMD64 architecture
        get { return Volatile.Read(ref m_currentPhase); }

        internal set { Volatile.Write(ref m_currentPhase, value); }
    }
}

這裏邊有幾個變量需要說明一下。m_currentTotalCount的第1-15位存的是參與者總數,第16位未使用,第17-31位存的是當前階段已到達的參與者數量。第32位(最高位)是sense標誌,每完成一個階段它就翻轉一次,後面根據它判斷等待ManualResetEventSlim的哪一個實例:m_oddEvent還是m_evenEvent。Barrier的構造函數就不多說了。如果指定了postPhaseAction,並且當前允許捕獲線程的執行上下文(ExecutionContext的流動沒有被抑制),那麼需要捕獲當前上下文【m_ownerThreadContext = ExecutionContext.Capture()】,便於後面調用postPhaseAction。此外還要注意兩個事件的初始狀態【m_oddEvent = new ManualResetEventSlim(true);m_evenEvent = new ManualResetEventSlim(false);】。

AddParticipants用於增加總的參與者數目,RemoveParticipants則用於減少總的參與者數目,它們都是借助SpinWait的自旋和原子操作(Interlocked.CompareExchange)完成的。AddParticipants增加總的參與者時,可能當前階段恰好已經完成、但階段號還沒有遞增。這時新參與者將從下一階段開始參與,需要調用ManualResetEventSlim的Wait方法,等待剛完成的階段收尾;否則只需確保當前階段對應的事件處於重置狀態。RemoveParticipants減少參與者時,如果減少後剩餘的參與者數量恰好等於當前已到達的數量(current == total),程序就應該觸發階段結束,這裏調用FinishPhase;否則只是減少total的值。如果先前的構造函數傳入了回調,那麼FinishPhase需要調用回調函數。如果先前捕獲了線程上下文,那麼回調需要在該上下文中執行【ExecutionContext.Run(currentContext, InvokePostPhaseAction, this);】,否則只是簡單的方法調用【m_postPhaseAction(this)】。

現在我們再來看SignalAndWait方法。SignalAndWait方法也是借助SpinWait的自旋和原子操作完成的,其核心操作就是current = current + 1。如果加一後current == total,就調用FinishPhase,FinishPhase中會調用回調函數,並發出Set信號。如果調用SignalAndWait方法後current < total,那麼繼續往下執行,調用DiscontinuousWait方法阻塞當前任務(線程),直到其他任務調用SignalAndWait方法(current == total時調用FinishPhase方法,發出Set信號)。有關Barrier的使用,在一本pdf裏面發現了一張比較好的圖片(原文圖片已缺失):

相關文章
相關標籤/搜索