記一次被yield return坑的歷程。

事情的通過是這樣的:函數

我用C#寫了一個很簡單的一個經過迭代生成序列的函數。單元測試

public static IEnumerable<T> Iterate<T>(this Func<T, T> f, T initVal, int length)
{
    Checker.NullCheck(nameof(f), f);    
    Checker.RangeCheck(nameof(length), length, 0, int.MaxValue);

    var current = initVal;
    while (--length >= 0)
    {
        yield return (current = f(current));
    }
}

其中NullCheck用於檢查參數是否爲null,若是是則拋出ArgumentNullException異常。測試

對應的,我寫了以下單元測試代碼去檢測這個異常。this

public void TestIterate()
{
    Func<int, int> f = null;
    Assert.Throws<ArgumentNullException>(() => f.Iterate(1, 7));
    
    // Other tests
}

可是,這個測試出乎意料的fail了。spa

一開始,我覺得是NullCheck函數的問題,可我把NullCheck直接換成了if語句,仍是通不過。調試

後來我在Iterate函數下斷點並調試。結果調試器根本沒有停在斷點上,直接運行完了測試。code

我覺得是我測試的方法不對,因此我不斷的修改測試代碼,甚至還一度覺得是.NET的Unit Tests出了bug。對象

最終,我在這個測試代碼發現了問題:blog

Assert.Throws<ArgumentNullException>(() =>
{
    var seq = f.Iterate(1, 7);
    foreach (int ele in seq)
        Console.WriteLine(ele);
});

當我調試這個測試時,程序停在了我以前在Iterate函數上下的斷點。工作流

因而,我在 var seq = f.Iterate(1, 7); 上下斷點,並逐步運行。這時我發現,當程序運行到 var seq = f.Iterate(1, 7); 時並不會進入Iterate函數;而是當程序運行到foreach語句後才進入。

這就要涉及到yield return的具體工做流程。當函數代碼中出現yield return,調用這個函數會直接返回一個IEnumerable<T>或IEnumerator<T>對象,並不會執行函數體的任何代碼。這些代碼都被封裝到了返回對象的內部。它們會在你開始枚舉的時候開始執行。

所以,上面兩個Check並不會在函數調用時執行,而是在當你開始foreach的時候才執行。

這並非我想要的結果。我但願在調用函數時就檢查參數合法性,若是不合法便直接拋出異常。

解決這個問題有兩種途徑,一是把它拆成兩個函數:

public static IEnumerable<T> Iterate<T>(this Func<T, T> f, T initVal, int length)
{
    Checker.NullCheck(nameof(f), f);    
    Checker.RangeCheck(nameof(length), length, 0, int.MaxValue);
            
    return IterateWithoutCheck(f, initVal, length);
}

private static IEnumerable<T> IterateWithoutCheck<T>(this Func<T, T> f, T initVal, int length)
{
    var current = initVal;
    while (--length >= 0)
    {
        yield return (current = f(current));
    }
}

或者,你也能夠將這個函數包裝成一個類。

class FunctionIterator<T> : IEnumerable<T>
{
    private readonly Func<T, T> f;
    private readonly T initVal;
    private readonly int length;
        
    public FunctionIterator(Func<T, T> f, T initVal, int length)
    {
        Checker.NullCheck(nameof(f), f);
        Checker.RangeCheck(nameof(length), length, 0, int.MaxValue);
            
        this.f = f;
        this.initVal = initVal;
        this.length = length;
    }

    public IEnumerator<T> GetEnumerator()
    {
        T current = initVal;

        for (int i = 0; i < length; ++i)
            yield return (current = f(current));
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }
}
相關文章
相關標籤/搜索