C# CancellationTokenSource和CancellationToken的實現

微軟關於CancellationTokenSource的介紹很簡單,其實CancellationTokenSource的使用也很簡單,可是實現就不是那麼簡單了,咱們首先來看看CancellationTokenSource的實現:

/// <summary>
/// Signals cancellation to the <see cref="CancellationToken"/> values it hands out and runs the
/// callbacks that were registered through them.
/// </summary>
/// <remarks>
/// This listing is an excerpt. Members such as <c>m_kernelEvent</c>, <c>m_linkingRegistrations</c>,
/// <c>m_disposed</c>, <c>IsDisposed</c>, <c>CanBeCanceled</c>, <c>ThreadIDExecutingCallbacks</c>,
/// <c>ThrowIfDisposed</c> and the <see cref="IDisposable"/> implementation are used below but are
/// declared outside of it.
/// </remarks>
public class CancellationTokenSource : IDisposable
{
    // Values stored in m_state. IsCancellationRequested depends on their numeric order
    // (it tests m_state >= NOTIFYING).
    private const int CANNOT_BE_CANCELED = 0;
    private const int NOT_CANCELED = 1;      // initial state written by the constructors
    private const int NOTIFYING = 2;         // written by the thread that wins the CAS in NotifyCancellation
    private const int NOTIFYINGCOMPLETE = 3; // written when ExecuteCallbackHandlers finishes

    private volatile int m_state;

    // Cached delegate registered on the source tokens by CreateLinkedTokenSource.
    private static readonly Action<object> s_LinkedTokenCancelDelegate = new Action<object>(LinkedTokenCancelDelegate);

    // Number of callback lists per source: the processor count, capped at 24.
    private static readonly int s_nLists = (PlatformHelper.ProcessorCount > 24) ? 24 : PlatformHelper.ProcessorCount;

    // The callback currently being run by ExecuteCallbackHandlers (null when none).
    private volatile CancellationCallbackInfo m_executingCallback;

    // Lazily allocated array of callback lists; a registering thread picks its slot by
    // ManagedThreadId % s_nLists (see InternalRegister).
    private volatile SparselyPopulatedArray<CancellationCallbackInfo>[] m_registeredCallbacksLists;

    // Cached delegate handed to every Timer created by this type.
    private static readonly TimerCallback s_timerCallback = new TimerCallback(TimerCallbackLogic);

    // Timer backing the delayed-cancel constructor and CancelAfter; null until one is needed.
    private volatile Timer m_timer;

    /// <summary>Creates a source in the <c>NOT_CANCELED</c> state.</summary>
    public CancellationTokenSource()
    {
        m_state = NOT_CANCELED;
    }

    /// <summary>
    /// Constructs a CancellationTokenSource that will be canceled after a specified delay.
    /// </summary>
    /// <param name="millisecondsDelay">Delay in milliseconds; must be -1 or greater.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="millisecondsDelay"/> is less than -1.</exception>
    public CancellationTokenSource(Int32 millisecondsDelay)
    {
        if (millisecondsDelay < -1)
        {
            throw new ArgumentOutOfRangeException("millisecondsDelay");
        }

        InitializeWithTimer(millisecondsDelay);
    }

    /// <summary>Sets the initial state and starts a one-shot timer (period -1) that cancels this source.</summary>
    private void InitializeWithTimer(Int32 millisecondsDelay)
    {
        m_state = NOT_CANCELED;
        m_timer = new Timer(s_timerCallback, this, millisecondsDelay, -1);
    }

    /// <summary>
    /// Timer callback: cancels the source passed as timer state unless it has been disposed.
    /// </summary>
    /// <param name="obj">The <see cref="CancellationTokenSource"/> that owns the timer.</param>
    private static void TimerCallbackLogic(object obj)
    {
        CancellationTokenSource cts = (CancellationTokenSource)obj;
        if (!cts.IsDisposed)
        {
            try
            {
                cts.Cancel(); // will take care of disposing of m_timer
            }
            catch (ObjectDisposedException)
            {
                // The source may be disposed between the IsDisposed check and Cancel();
                // swallow that case only, and rethrow anything else.
                if (!cts.IsDisposed) throw;
            }
        }
    }

    /// <summary>Requests cancellation; equivalent to <c>Cancel(false)</c>.</summary>
    public void Cancel()
    {
        Cancel(false);
    }

    /// <summary>Requests cancellation and runs the registered callbacks on the calling thread.</summary>
    /// <param name="throwOnFirstException">
    /// When true, the first exception thrown by a callback propagates immediately; when false,
    /// exceptions are collected and rethrown together as an <see cref="AggregateException"/>
    /// after all callbacks have been visited.
    /// </param>
    public void Cancel(bool throwOnFirstException)
    {
        ThrowIfDisposed();
        NotifyCancellation(throwOnFirstException);
    }

    /// <summary>Schedules (or reschedules) cancellation of this source after the given delay.</summary>
    /// <param name="millisecondsDelay">Delay in milliseconds; must be -1 or greater.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="millisecondsDelay"/> is less than -1.</exception>
    public void CancelAfter(Int32 millisecondsDelay)
    {
        ThrowIfDisposed();

        if (millisecondsDelay < -1)
        {
            throw new ArgumentOutOfRangeException("millisecondsDelay");
        }

        if (IsCancellationRequested) return;

        if (m_timer == null)
        {
            // Create the timer disarmed (-1, -1) and publish it with a CAS; if another thread
            // published one first, discard ours.
            Timer newTimer = new Timer(s_timerCallback, this, -1, -1);
            if (Interlocked.CompareExchange(ref m_timer, newTimer, null) != null)
            {
                newTimer.Dispose();
            }
        }

        // It is possible that m_timer has already been disposed, so we must do
        // the following in a try/catch block.
        try
        {
            m_timer.Change(millisecondsDelay, -1);
        }
        catch (ObjectDisposedException)
        {
            // Deliberately ignored: NotifyCancellation disposes the timer when cancellation
            // starts, in which case there is nothing left to schedule.
        }
    }

    /// <summary>
    /// Moves the source from <c>NOT_CANCELED</c> to <c>NOTIFYING</c> and, on the one thread that
    /// wins that transition, disposes the timer, sets the kernel event if it exists and runs
    /// the callbacks.
    /// </summary>
    private void NotifyCancellation(bool throwOnFirstException)
    {
        // Fast path: someone has already requested cancellation.
        if (IsCancellationRequested)
            return;

        // If we're the first to signal cancellation, do the main extra work.
        if (Interlocked.CompareExchange(ref m_state, NOTIFYING, NOT_CANCELED) == NOT_CANCELED)
        {
            Timer timer = m_timer;
            if (timer != null) timer.Dispose();

            // Record the threadID being used for running the callbacks.
            ThreadIDExecutingCallbacks = Thread.CurrentThread.ManagedThreadId;

            // If the kernel event is null at this point, it will be set during lazy construction.
            if (m_kernelEvent != null)
                m_kernelEvent.Set(); // update the MRE value.

            ExecuteCallbackHandlers(throwOnFirstException);
            Contract.Assert(IsCancellationCompleted, "Expected cancellation to have finished");
        }
    }

    /// <summary>
    /// Invokes the registered callbacks synchronously. Lists are visited by ascending index;
    /// within a list, fragments are walked from the tail backwards and each fragment's slots
    /// from the last index down to 0.
    /// </summary>
    /// <param name="throwOnFirstException">See <see cref="Cancel(bool)"/>.</param>
    /// <exception cref="AggregateException">
    /// One or more callbacks threw and <paramref name="throwOnFirstException"/> is false.
    /// </exception>
    private void ExecuteCallbackHandlers(bool throwOnFirstException)
    {
        Contract.Assert(IsCancellationRequested, "ExecuteCallbackHandlers should only be called after setting IsCancellationRequested->true");
        Contract.Assert(ThreadIDExecutingCallbacks != -1, "ThreadIDExecutingCallbacks should have been set.");

        List<Exception> exceptionList = null;
        SparselyPopulatedArray<CancellationCallbackInfo>[] callbackLists = m_registeredCallbacksLists;

        // Nothing was ever registered: just mark the notification as complete.
        if (callbackLists == null)
        {
            Interlocked.Exchange(ref m_state, NOTIFYINGCOMPLETE);
            return;
        }

        try
        {
            for (int index = 0; index < callbackLists.Length; index++)
            {
                SparselyPopulatedArray<CancellationCallbackInfo> list = Volatile.Read<SparselyPopulatedArray<CancellationCallbackInfo>>(ref callbackLists[index]);
                if (list != null)
                {
                    SparselyPopulatedArrayFragment<CancellationCallbackInfo> currArrayFragment = list.Tail;

                    while (currArrayFragment != null)
                    {
                        for (int i = currArrayFragment.Length - 1; i >= 0; i--)
                        {
                            // Publish the candidate through m_executingCallback before running it;
                            // CancellationCallbackCoreWork removes the slot only if it still holds
                            // this same instance.
                            m_executingCallback = currArrayFragment[i];
                            if (m_executingCallback != null)
                            {
                                CancellationCallbackCoreWorkArguments args = new CancellationCallbackCoreWorkArguments(currArrayFragment, i);
                                try
                                {
                                    if (m_executingCallback.TargetSyncContext != null)
                                    {
                                        // Run the callback through the captured synchronization context.
                                        m_executingCallback.TargetSyncContext.Send(CancellationCallbackCoreWork_OnSyncContext, args);
                                        // CancellationCallbackCoreWork may have overwritten the recorded
                                        // thread id, so set it back to this thread.
                                        ThreadIDExecutingCallbacks = Thread.CurrentThread.ManagedThreadId;
                                    }
                                    else
                                    {
                                        CancellationCallbackCoreWork(args);
                                    }
                                }
                                catch (Exception ex)
                                {
                                    if (throwOnFirstException)
                                        throw;

                                    // Otherwise remember the failure and keep going.
                                    if (exceptionList == null)
                                        exceptionList = new List<Exception>();
                                    exceptionList.Add(ex);
                                }
                            }
                        }
                        currArrayFragment = currArrayFragment.Prev;
                    }
                }
            }
        }
        finally
        {
            // Runs on both the normal and the throw-on-first-exception path.
            m_state = NOTIFYINGCOMPLETE;
            m_executingCallback = null;
            Thread.MemoryBarrier(); // for safety, prevent reorderings crossing this point and seeing inconsistent state.
        }

        if (exceptionList != null)
        {
            Contract.Assert(exceptionList.Count > 0, "Expected exception count > 0");
            throw new AggregateException(exceptionList);
        }
    }

    /// <summary>Adapter with the signature <c>SynchronizationContext.Send</c> expects; forwards to <see cref="CancellationCallbackCoreWork"/>.</summary>
    private void CancellationCallbackCoreWork_OnSyncContext(object obj)
    {
        CancellationCallbackCoreWork((CancellationCallbackCoreWorkArguments)obj);
    }

    /// <summary>
    /// Atomically removes the current callback from its slot and, if the removed entry is the
    /// one recorded in <c>m_executingCallback</c>, executes it.
    /// </summary>
    private void CancellationCallbackCoreWork(CancellationCallbackCoreWorkArguments args)
    {
        // Remove first, so the callback runs only if this thread actually took it out of the slot.
        CancellationCallbackInfo callback = args.m_currArrayFragment.SafeAtomicRemove(args.m_currArrayIndex, m_executingCallback);
        if (callback == m_executingCallback)
        {
            if (callback.TargetExecutionContext != null)
            {
                // Record the thread that is about to run the callback.
                callback.CancellationTokenSource.ThreadIDExecutingCallbacks = Thread.CurrentThread.ManagedThreadId;
            }
            callback.ExecuteCallback();
        }
    }

    /// <summary>
    /// Creates a source that is canceled when either of the given tokens is canceled.
    /// Tokens that cannot be canceled are skipped.
    /// </summary>
    public static CancellationTokenSource CreateLinkedTokenSource(CancellationToken token1, CancellationToken token2)
    {
        CancellationTokenSource linkedTokenSource = new CancellationTokenSource();
        bool token2CanBeCanceled = token2.CanBeCanceled;

        if (token1.CanBeCanceled)
        {
            linkedTokenSource.m_linkingRegistrations = new CancellationTokenRegistration[token2CanBeCanceled ? 2 : 1]; // there will be at least 1 and at most 2 linkings
            linkedTokenSource.m_linkingRegistrations[0] = token1.InternalRegisterWithoutEC(s_LinkedTokenCancelDelegate, linkedTokenSource);
        }

        if (token2CanBeCanceled)
        {
            int index = 1;
            if (linkedTokenSource.m_linkingRegistrations == null)
            {
                linkedTokenSource.m_linkingRegistrations = new CancellationTokenRegistration[1]; // this will be the only linking
                index = 0;
            }
            linkedTokenSource.m_linkingRegistrations[index] = token2.InternalRegisterWithoutEC(s_LinkedTokenCancelDelegate, linkedTokenSource);
        }
        return linkedTokenSource;
    }

    /// <summary>
    /// Creates a source that is canceled when any of the given tokens is canceled.
    /// Tokens that cannot be canceled leave their registration slot at its default value.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="tokens"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="tokens"/> is empty.</exception>
    public static CancellationTokenSource CreateLinkedTokenSource(params CancellationToken[] tokens)
    {
        if (tokens == null)
            throw new ArgumentNullException("tokens");

        if (tokens.Length == 0)
            throw new ArgumentException(Environment.GetResourceString("CancellationToken_CreateLinkedToken_TokensIsEmpty"));

        Contract.EndContractBlock();

        CancellationTokenSource linkedTokenSource = new CancellationTokenSource();
        linkedTokenSource.m_linkingRegistrations = new CancellationTokenRegistration[tokens.Length];

        for (int i = 0; i < tokens.Length; i++)
        {
            if (tokens[i].CanBeCanceled)
            {
                linkedTokenSource.m_linkingRegistrations[i] = tokens[i].InternalRegisterWithoutEC(s_LinkedTokenCancelDelegate, linkedTokenSource);
            }
        }
        return linkedTokenSource;
    }

    /// <summary>
    /// Registers a callback to run on cancellation. If cancellation has already been requested
    /// the callback is invoked synchronously on the calling thread and an empty registration
    /// is returned.
    /// </summary>
    /// <param name="callback">Delegate to invoke.</param>
    /// <param name="stateForCallback">Argument passed to <paramref name="callback"/>.</param>
    /// <param name="targetSyncContext">Synchronization context to run the callback through, or null.</param>
    /// <param name="executionContext">Execution context to run the callback under, or null.</param>
    internal CancellationTokenRegistration InternalRegister(Action<object> callback, object stateForCallback, SynchronizationContext targetSyncContext, ExecutionContext executionContext)
    {
        if (AppContextSwitches.ThrowExceptionIfDisposedCancellationTokenSource)
        {
            ThrowIfDisposed();
        }
        Contract.Assert(CanBeCanceled, "Cannot register for uncancelable token src");

        if (!IsCancellationRequested)
        {
            // Disposed source with the compatibility switch off: hand back an empty registration.
            if (m_disposed && !AppContextSwitches.ThrowExceptionIfDisposedCancellationTokenSource)
                return new CancellationTokenRegistration();

            // Pick one of the s_nLists lists based on the registering thread's id.
            int myIndex = Thread.CurrentThread.ManagedThreadId % s_nLists;

            CancellationCallbackInfo callbackInfo = new CancellationCallbackInfo(callback, stateForCallback, targetSyncContext, executionContext, this);

            // Allocate the callback list array; if another thread published one first, use theirs.
            var registeredCallbacksLists = m_registeredCallbacksLists;
            if (registeredCallbacksLists == null)
            {
                SparselyPopulatedArray<CancellationCallbackInfo>[] list = new SparselyPopulatedArray<CancellationCallbackInfo>[s_nLists];
                registeredCallbacksLists = Interlocked.CompareExchange(ref m_registeredCallbacksLists, list, null);
                if (registeredCallbacksLists == null) registeredCallbacksLists = list;
            }

            // Allocate the actual lists on-demand to save mem in low-use situations, and to avoid false-sharing.
            var callbacks = Volatile.Read<SparselyPopulatedArray<CancellationCallbackInfo>>(ref registeredCallbacksLists[myIndex]);
            if (callbacks == null)
            {
                SparselyPopulatedArray<CancellationCallbackInfo> callBackArray = new SparselyPopulatedArray<CancellationCallbackInfo>(4);
                Interlocked.CompareExchange(ref (registeredCallbacksLists[myIndex]), callBackArray, null);
                // Re-read: the slot now holds either our array or the one that won the race.
                callbacks = registeredCallbacksLists[myIndex];
            }

            // Now add the registration to the list.
            SparselyPopulatedArrayAddInfo<CancellationCallbackInfo> addInfo = callbacks.Add(callbackInfo);
            CancellationTokenRegistration registration = new CancellationTokenRegistration(callbackInfo, addInfo);

            if (!IsCancellationRequested)
                return registration;

            // Cancellation was requested while we were registering. Try to take the callback
            // back out of the list so it can be run here instead.
            bool deregisterOccurred = registration.TryDeregister();

            if (!deregisterOccurred)
            {
                // The entry could not be removed by us (presumably the cancelling thread has
                // already taken it - confirm against TryDeregister), so do not run it here.
                return registration;
            }
        }

        // If cancellation already occurred, we run the callback on this thread and return an empty registration.
        callback(stateForCallback);
        return new CancellationTokenRegistration();
    }

    /// <summary>True once the state has reached <c>NOTIFYING</c> or <c>NOTIFYINGCOMPLETE</c>.</summary>
    public bool IsCancellationRequested
    {
        get { return m_state >= NOTIFYING; }
    }

    /// <summary>True once the state is <c>NOTIFYINGCOMPLETE</c>.</summary>
    internal bool IsCancellationCompleted
    {
        get { return m_state == NOTIFYINGCOMPLETE; }
    }

    /// <summary>Gets a new <see cref="CancellationToken"/> bound to this source.</summary>
    public CancellationToken Token
    {
        get
        {
            ThrowIfDisposed();
            return new CancellationToken(this);
        }
    }

    /// <summary>The callback currently being executed by the cancelling thread, or null.</summary>
    internal CancellationCallbackInfo ExecutingCallback
    {
        get { return m_executingCallback; }
    }

    /// <summary>
    /// Callback registered on the source tokens of a linked source: cancels the linked source.
    /// </summary>
    /// <param name="source">The linked <see cref="CancellationTokenSource"/> to cancel.</param>
    private static void LinkedTokenCancelDelegate(object source)
    {
        // Direct cast (as in TimerCallbackLogic): a wrong-typed state object now fails with an
        // InvalidCastException at the cast rather than a NullReferenceException on the next line.
        CancellationTokenSource cts = (CancellationTokenSource)source;
        Contract.Assert(cts != null);
        cts.Cancel();
    }
}

CancellationTokenSource的實現相對比較複雜,咱們首先看看CancellationTokenSource的構造函數,默認構造函數將會設置【m_state = NOT_CANCELED】,咱們也能夠構造一個特定時間後就自動Cancel的CancellationTokenSource,自動Cancel是依賴一個Timer實例,在Timer到指定時間後調用CancellationTokenSource的Cancel方法【這裏是在TimerCallbackLogic裏面調用Cancel方法】,CancelAfter方法的實現也是依賴這個Timer實例和TimerCallbackLogic方法。

如今咱們來看看CancellationTokenSource最主要的一個方法Cancel,Cancel方法調用NotifyCancellation方法,NotifyCancellation方法主要調用ExecuteCallbackHandlers【從這個方法的名稱能夠猜想到主要是調用回調方法】,在ExecuteCallbackHandlers方法裏面用到一個變量m_registeredCallbacksLists,它是SparselyPopulatedArray<CancellationCallbackInfo>[]結構,【能夠理解爲是一個鏈表的數組,數組每一個元素是一個鏈表,鏈表裏面的每一個節點均可以訪問下一個節點】,咱們遍歷這個鏈表數組的每個節點,檢查節點是否有值,即m_executingCallback != null,而後調用回調方法,若是回調方法的TargetSyncContext不爲空,調用CancellationCallbackCoreWork_OnSyncContext方法,否則調用CancellationCallbackCoreWork方法【CancellationCallbackCoreWork_OnSyncContext裏面也是調用它】,CancellationCallbackCoreWork方法是調用CancellationCallbackInfo的ExecuteCallback。

CancellationTokenSource有兩個CreateLinkedTokenSource方法【能夠理解爲建立與傳入的CancellationToken相關聯的CancellationTokenSource】,其主要實現是CancellationToken的註冊方法【上面代碼中調用的是InternalRegisterWithoutEC】。

/// <summary>
/// Lightweight handle onto a <see cref="CancellationTokenSource"/>: lets code observe the
/// source's cancellation state and register callbacks with it. A token with no source
/// (the default value) reports that it cannot be canceled.
/// </summary>
public struct CancellationToken
{
    // Adapter that lets a parameterless Action travel through the Action<object> pipeline:
    // the Action itself is passed as the callback state.
    private static readonly Action<Object> s_ActionToActionObjShunt = new Action<Object>(ActionToActionObjShunt);

    // The source this token observes; null for a default token.
    private CancellationTokenSource m_source;

    /// <summary>Creates a token bound to <paramref name="source"/>.</summary>
    internal CancellationToken(CancellationTokenSource source)
    {
        m_source = source;
    }

    /// <summary>
    /// Creates a token in a fixed state. When <paramref name="canceled"/> is true the token is
    /// bound to the static source returned by <c>InternalGetStaticSource</c>; otherwise it is
    /// left without a source.
    /// </summary>
    public CancellationToken(bool canceled) : this()
    {
        if (!canceled)
        {
            return;
        }

        m_source = CancellationTokenSource.InternalGetStaticSource(canceled);
    }

    /// <summary>The default token, which has no source.</summary>
    public static CancellationToken None
    {
        get
        {
            CancellationToken noSource = default(CancellationToken);
            return noSource;
        }
    }

    /// <summary>True when the token has a source and that source has had cancellation requested.</summary>
    public bool IsCancellationRequested
    {
        get
        {
            if (m_source == null)
            {
                return false;
            }

            return m_source.IsCancellationRequested;
        }
    }

    /// <summary>True when the token has a source and that source can be canceled.</summary>
    public bool CanBeCanceled
    {
        get
        {
            if (m_source == null)
            {
                return false;
            }

            return m_source.CanBeCanceled;
        }
    }

    /// <summary>Registers a parameterless callback without capturing the synchronization context.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="callback"/> is null.</exception>
    public CancellationTokenRegistration Register(Action callback)
    {
        if (callback != null)
        {
            return Register(s_ActionToActionObjShunt, callback, false, true);
        }

        throw new ArgumentNullException("callback");
    }

    /// <summary>Registers a parameterless callback, optionally capturing the current synchronization context.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="callback"/> is null.</exception>
    public CancellationTokenRegistration Register(Action callback, bool useSynchronizationContext)
    {
        if (callback != null)
        {
            return Register(s_ActionToActionObjShunt, callback, useSynchronizationContext, true);
        }

        throw new ArgumentNullException("callback");
    }

    /// <summary>Registers a callback with a state argument, without capturing the synchronization context.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="callback"/> is null.</exception>
    public CancellationTokenRegistration Register(Action<Object> callback, Object state)
    {
        if (callback != null)
        {
            return Register(callback, state, false, true);
        }

        throw new ArgumentNullException("callback");
    }

    /// <summary>
    /// Registers a delegate that will be called when this CancellationToken is canceled,
    /// optionally capturing the current synchronization context.
    /// </summary>
    public CancellationTokenRegistration Register(Action<Object> callback, Object state, bool useSynchronizationContext)
    {
        // The null check on callback happens in the private overload.
        return Register(callback, state, useSynchronizationContext, true);
    }

    /// <summary>
    /// Core registration routine shared by the public overloads: captures the requested
    /// contexts (only while cancellation has not been requested) and hands the callback to
    /// the source.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="callback"/> is null.</exception>
    private CancellationTokenRegistration Register(Action<Object> callback, Object state, bool useSynchronizationContext, bool useExecutionContext)
    {
        // Kept as the first local; it is passed by reference to ExecutionContext.Capture below.
        StackCrawlMark stackMark = StackCrawlMark.LookForMyCaller;

        if (callback == null)
        {
            throw new ArgumentNullException("callback");
        }

        // Tokens that can never be canceled get a dummy registration.
        if (!CanBeCanceled)
        {
            return new CancellationTokenRegistration();
        }

        bool alreadyRequested = IsCancellationRequested;
        SynchronizationContext syncContextToUse = null;
        ExecutionContext execContextToUse = null;

        if (!alreadyRequested && useSynchronizationContext)
        {
            syncContextToUse = SynchronizationContext.Current;
        }

        if (!alreadyRequested && useExecutionContext)
        {
            execContextToUse = ExecutionContext.Capture(ref stackMark, ExecutionContext.CaptureOptions.OptimizeDefaultCase);
        }

        // Register the callback with the source.
        CancellationTokenRegistration result = m_source.InternalRegister(callback, state, syncContextToUse, execContextToUse);
        return result;
    }

    /// <summary>Unwraps the <see cref="Action"/> passed as state and invokes it.</summary>
    private static void ActionToActionObjShunt(object obj)
    {
        Action wrapped = obj as Action;
        Contract.Assert(wrapped != null, "Expected an Action here");
        wrapped();
    }

    /// <summary>Throws <see cref="OperationCanceledException"/> if cancellation has been requested.</summary>
    public void ThrowIfCancellationRequested()
    {
        if (!IsCancellationRequested)
        {
            return;
        }

        ThrowOperationCanceledException();
    }

    /// <summary>Throws an <see cref="OperationCanceledException"/> that carries this token.</summary>
    private void ThrowOperationCanceledException()
    {
        throw new OperationCanceledException(Environment.GetResourceString("OperationCanceled"), this);
    }
}

CancellationToken的不少屬性都是來源於CancellationTokenSource的屬性,CancellationToken的主要方法 Register 也是調用CancellationTokenSource的InternalRegister方法。InternalRegister方法檢查當前是否發起了Cancel【IsCancellationRequested】,若是是直接調用回調方法callback(stateForCallback);,否則把回調方法包裝成CancellationCallbackInfo實例,而後添加到m_registeredCallbacksLists對象中,而後再返回CancellationTokenRegistration實例。

    /// <summary>
    /// Immutable record of one registered cancellation callback: the delegate, its state
    /// argument, the contexts it should run under, and the source it was registered with.
    /// </summary>
    internal class CancellationCallbackInfo
    {
        // Cached delegate for ExecutionContext.Run, created on first use.
        private static ContextCallback s_executionContextCallback;

        internal readonly Action<object> Callback;
        internal readonly object StateForCallback;
        internal readonly SynchronizationContext TargetSyncContext;
        internal readonly ExecutionContext TargetExecutionContext;
        internal readonly CancellationTokenSource CancellationTokenSource;

        /// <summary>Stores the given values; no validation is performed.</summary>
        internal CancellationCallbackInfo(Action<object> callback, object stateForCallback, SynchronizationContext targetSyncContext, ExecutionContext targetExecutionContext, CancellationTokenSource cancellationTokenSource)
        {
            this.Callback = callback;
            this.StateForCallback = stateForCallback;
            this.TargetSyncContext = targetSyncContext;
            this.TargetExecutionContext = targetExecutionContext;
            this.CancellationTokenSource = cancellationTokenSource;
        }

        /// <summary>
        /// Invokes <see cref="Callback"/> with <see cref="StateForCallback"/>, under
        /// <see cref="TargetExecutionContext"/> when one was supplied and directly otherwise.
        /// </summary>
        internal void ExecuteCallback()
        {
            // No execution context: call straight through.
            if (TargetExecutionContext == null)
            {
                ExecutionContextCallback(this);
                return;
            }

            ContextCallback runner = s_executionContextCallback;
            if (runner == null)
            {
                runner = new ContextCallback(ExecutionContextCallback);
                s_executionContextCallback = runner;
            }

            ExecutionContext.Run(TargetExecutionContext, runner, this);
        }

        /// <summary>Shared entry point for both paths: <paramref name="obj"/> is the info whose callback should run.</summary>
        private static void ExecutionContextCallback(object obj)
        {
            CancellationCallbackInfo info = obj as CancellationCallbackInfo;
            Contract.Assert(info != null);
            info.Callback(info.StateForCallback);
        }
    }
    
    /// <summary>
    /// A growable collection of reference-type elements stored in a chain of fixed-size
    /// fragments, where a null slot means "free". Elements are placed into slots with
    /// compare-and-swap operations.
    /// </summary>
    internal class SparselyPopulatedArray<T> where T : class
    {
        // First fragment; assigned in the constructor and not read anywhere in this excerpt.
        private readonly SparselyPopulatedArrayFragment<T> m_head;
        // Most recently published last fragment; Add refreshes it by following m_next links.
        private volatile SparselyPopulatedArrayFragment<T> m_tail;

        /// <summary>Creates the array with a single fragment of <paramref name="initialSize"/> slots.</summary>
        internal SparselyPopulatedArray(int initialSize)
        {
            m_head = m_tail = new SparselyPopulatedArrayFragment<T>(initialSize);
        }

        /// <summary>The current tail fragment; callers walk backwards from it via <c>Prev</c>.</summary>
        internal SparselyPopulatedArrayFragment<T> Tail
        {
            get { return m_tail; }
        }

        /// <summary>
        /// Stores <paramref name="element"/> in a free slot, appending a new fragment when no
        /// searched fragment yields one, and returns the fragment and index it was placed at.
        /// </summary>
        /// <param name="element">The element to store.</param>
        /// <returns>The fragment/index pair identifying where the element now lives.</returns>
        internal SparselyPopulatedArrayAddInfo<T> Add(T element)
        {
            // Loops until a slot is claimed; the only exit is the return inside the search.
            while (true)
            {
                // Get the tail, and ensure it's up to date.
                SparselyPopulatedArrayFragment<T> tail = m_tail;
                while (tail.m_next != null)
                    m_tail = (tail = tail.m_next);

                // Search for a free index, starting from the tail.
                SparselyPopulatedArrayFragment<T> curr = tail;
                while (curr != null)
                {
                    const int RE_SEARCH_THRESHOLD = -10; // Every 10 skips, force a search.
                    // A fragment whose hint says "full" is skipped, but each skip pushes the
                    // hint further below zero so that it is eventually searched again.
                    if (curr.m_freeCount < 1)
                        --curr.m_freeCount;

                    if (curr.m_freeCount > 0 || curr.m_freeCount < RE_SEARCH_THRESHOLD)
                    {
                        int c = curr.Length;
                        // Starting index derived from the free-count hint.
                        int start = ((c - curr.m_freeCount) % c);
                        if (start < 0)
                        {
                            start = 0;
                            curr.m_freeCount--; // Too many free elements; fix up.
                        }
                        Contract.Assert(start >= 0 && start < c, "start is outside of bounds");

                        // Now walk the array until we find a free slot (or reach the end).
                        for (int i = 0; i < c; i++)
                        {
                            // If the slot is null, try to CAS our element into it.
                            int tryIndex = (start + i) % c;
                            Contract.Assert(tryIndex >= 0 && tryIndex < curr.m_elements.Length, "tryIndex is outside of bounds");
                            
                            // Plain read first, then the CAS; only a successful CAS claims the slot.
                            if (curr.m_elements[tryIndex] == null && Interlocked.CompareExchange(ref curr.m_elements[tryIndex], element, null) == null)
                            {
                                // Update the hint, clamping it at zero.
                                int newFreeCount = curr.m_freeCount - 1;
                                curr.m_freeCount = newFreeCount > 0 ? newFreeCount : 0;
                                return new SparselyPopulatedArrayAddInfo<T>(curr, tryIndex);
                            }
                        }
                    }

                    curr = curr.m_prev;
                }

                // If we got here, we need to add a new chunk to the tail and try again.
                // The new fragment is twice the tail's length, except that a 4096-slot tail is
                // followed by another 4096-slot fragment.
                SparselyPopulatedArrayFragment<T> newTail = new SparselyPopulatedArrayFragment<T>(
                    tail.m_elements.Length == 4096 ? 4096 : tail.m_elements.Length * 2, tail);
                // Only the thread whose CAS links the fragment publishes it as m_tail; either
                // way the outer loop runs again and re-reads the tail.
                if (Interlocked.CompareExchange(ref tail.m_next, newTail, null) == null)
                {
                    m_tail = newTail;
                }
            }
        }
    }
    
    /// <summary>
    /// Identifies where an element was stored: the fragment that holds it and the slot
    /// index inside that fragment.
    /// </summary>
    internal struct SparselyPopulatedArrayAddInfo<T> where T : class
    {
        private SparselyPopulatedArrayFragment<T> m_source;
        private int m_index;

        /// <summary>Records the fragment and slot index of a stored element.</summary>
        /// <param name="source">The fragment holding the element; expected to be non-null.</param>
        /// <param name="index">Slot index inside <paramref name="source"/>; expected to be within its length.</param>
        internal SparselyPopulatedArrayAddInfo(SparselyPopulatedArrayFragment<T> source, int index)
        {
            Contract.Assert(source != null);
            Contract.Assert(!(index < 0 || index >= source.Length));

            this.m_source = source;
            this.m_index = index;
        }

        /// <summary>Slot index inside <see cref="Source"/>.</summary>
        internal int Index
        {
            get
            {
                return this.m_index;
            }
        }

        /// <summary>The fragment that holds the element.</summary>
        internal SparselyPopulatedArrayFragment<T> Source
        {
            get
            {
                return this.m_source;
            }
        }
    }
    
    /// <summary>
    /// One fixed-size segment of a <c>SparselyPopulatedArray</c>: an element array in which
    /// null marks a free slot, a free-slot hint, and links to the neighbouring fragments.
    /// </summary>
    internal class SparselyPopulatedArrayFragment<T> where T : class
    {
        internal readonly T[] m_elements; // The contents, sparsely populated (with nulls).
        internal volatile int m_freeCount; // A hint of the number of free elements.
        internal volatile SparselyPopulatedArrayFragment<T> m_next; // The next fragment in the chain.
        internal volatile SparselyPopulatedArrayFragment<T> m_prev; // The previous fragment in the chain.

        /// <summary>Creates a fragment of <paramref name="size"/> slots with no previous fragment.</summary>
        internal SparselyPopulatedArrayFragment(int size) : this(size, null)
        {
        }

        /// <summary>
        /// Creates a fragment of <paramref name="size"/> empty slots linked back to
        /// <paramref name="prev"/>; the free-count hint starts at <paramref name="size"/>.
        /// </summary>
        internal SparselyPopulatedArrayFragment(int size, SparselyPopulatedArrayFragment<T> prev)
        {
            m_elements = new T[size];
            m_freeCount = size;
            m_prev = prev;
        }

        /// <summary>Volatile read of the element at <paramref name="index"/> (null when the slot is free).</summary>
        internal T this[int index]
        {
            get { return Volatile.Read<T>(ref m_elements[index]); }
        }

        /// <summary>Number of slots in this fragment.</summary>
        internal int Length
        {
            get { return m_elements.Length; }
        }

        /// <summary>The previous fragment in the chain, or null.</summary>
        internal SparselyPopulatedArrayFragment<T> Prev
        {
            get { return m_prev; }
        }

        /// <summary>
        /// Atomically clears the slot at <paramref name="index"/> if it currently holds
        /// <paramref name="expectedElement"/>.
        /// </summary>
        /// <param name="index">Slot to clear.</param>
        /// <param name="expectedElement">The element the caller believes is in the slot.</param>
        /// <returns>
        /// The value that was in the slot before the call. The removal happened only when this
        /// is the same instance as <paramref name="expectedElement"/>.
        /// </returns>
        internal T SafeAtomicRemove(int index, T expectedElement)
        {
            T prevailingValue = Interlocked.CompareExchange(ref m_elements[index], null, expectedElement);

            // Bump the free-slot hint only when a slot was really freed. A non-null prevailing
            // value that is a different instance means the compare-exchange failed and the slot
            // is still occupied, so counting it would inflate the hint.
            if (prevailingValue != null && object.ReferenceEquals(prevailingValue, expectedElement))
                ++m_freeCount; // non-atomic on purpose: m_freeCount is only a hint
            return prevailingValue;
        }
    }

回頭看CancellationCallbackInfo的實現也很簡單。

相關文章
相關標籤/搜索