事情的通過是這樣的:函數
我用C#寫了一個很簡單的一個經過迭代生成序列的函數。單元測試
public static IEnumerable<T> Iterate<T>(this Func<T, T> f, T initVal, int length) { Checker.NullCheck(nameof(f), f); Checker.RangeCheck(nameof(length), length, 0, int.MaxValue); var current = initVal; while (--length >= 0) { yield return (current = f(current)); } }
其中NullCheck用於檢查參數是否爲null,若是是則拋出ArgumentNullException異常。測試
對應的,我寫了以下單元測試代碼去檢測這個異常。this
public void TestIterate() { Func<int, int> f = null; Assert.Throws<ArgumentNullException>(() => f.Iterate(1, 7)); // Other tests }
可是,這個測試出乎意料的fail了。spa
一開始,我覺得是NullCheck函數的問題,可我把NullCheck直接換成了if語句,仍是通不過。調試
後來我在Iterate函數下斷點並調試。結果調試器根本沒有停在斷點上,直接運行完了測試。code
我覺得是我測試的方法不對,因此我不斷的修改測試代碼,甚至還一度覺得是.NET的Unit Tests出了bug。對象
最終,我在這個測試代碼發現了問題:blog
Assert.Throws<ArgumentNullException>(() => { var seq = f.Iterate(1, 7); foreach (int ele in seq) Console.WriteLine(ele); });
當我調試這個測試時,程序停在了我以前在Iterate函數上下的斷點。工作流
因而,我在 var seq = f.Iterate(1, 7); 上下斷點,並逐步運行。這時我發現,當程序運行到 var seq = f.Iterate(1, 7); 時並不會進入Iterate函數;而是當程序運行到foreach語句後才進入。
這就要涉及到yield return的具體工做流程。當函數代碼中出現yield return,調用這個函數會直接返回一個IEnumerable<T>或IEnumerator<T>對象,並不會執行函數體的任何代碼。這些代碼都被封裝到了返回對象的內部。它們會在你開始枚舉的時候開始執行。
所以,上面兩個Check並不會在函數調用時執行,而是在當你開始foreach的時候才執行。
這並非我想要的結果。我但願在調用函數時就檢查參數合法性,若是不合法便直接拋出異常。
解決這個問題有兩種途徑,一是把它拆成兩個函數:
public static IEnumerable<T> Iterate<T>(this Func<T, T> f, T initVal, int length) { Checker.NullCheck(nameof(f), f); Checker.RangeCheck(nameof(length), length, 0, int.MaxValue); return IterateWithoutCheck(f, initVal, length); } private static IEnumerable<T> IterateWithoutCheck<T>(this Func<T, T> f, T initVal, int length) { var current = initVal; while (--length >= 0) { yield return (current = f(current)); } }
或者,你也能夠將這個函數包裝成一個類。
class FunctionIterator<T> : IEnumerable<T> { private readonly Func<T, T> f; private readonly T initVal; private readonly int length; public FunctionIterator(Func<T, T> f, T initVal, int length) { Checker.NullCheck(nameof(f), f); Checker.RangeCheck(nameof(length), length, 0, int.MaxValue); this.f = f; this.initVal = initVal; this.length = length; } public IEnumerator<T> GetEnumerator() { T current = initVal; for (int i = 0; i < length; ++i) yield return (current = f(current)); } System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { return GetEnumerator(); } }