使用Unity攔截一個返回Task的方法

目標

主要是想爲服務方法注入公用的異常處理代碼,從而使得業務代碼簡潔。本人使用Unity.Interception主鍵來達到這個目標。因爲但願默認就執行攔截,因此使用了虛方法攔截器。要實現攔截,須要實現一個攔截處理類,此類型要求實現接口ICallHandler,例如:express

public class ServiceHandler : ICallHandler
    {
        public IMethodReturn Invoke(IMethodInvocation input,
            GetNextHandlerDelegate getNext)
        {
            Trace.WriteLine("開始調用");
            IMethodReturn result;
            result = getNext()(input, getNext);
            if (result.Exception != null)
            {
                Trace.WriteLine("發生了異常");
            }
            Trace.WriteLine("結束調用");
            return result;
        }

        public int Order { get; set; }
    }
View Code

此外還定義了servicebase基類。緩存

    public class ServiceBase
    {
    }
    

該類型沒有任何方法,這裏添加一個繼承與該類型的子類,同時添加一個方法(注意,因爲使用虛方法攔截,須要攔截的方法必須標記爲virtual)異步

    public class FakeService : ServiceBase
    {
        public virtual int GetInt()
        {
            throw new Exception("");
            return 100;
        }
    }
View Code

使用單元測試來調用這個方法,獲得的結果:async

以上是使用Unity在方法調用先後注入的例子,對於同步方法而言並不存在問題。因爲.NET 4.5引入了async和await,異步方法變得常見,使用傳統的方式注入變得行不通。其實,不單單是async方法,全部awaitable的方法都存在這個問題。ide

緣由很簡單,對於同步方法(不可等待的方法)而言,調用先後就是內部執行調用的先後,而對於返回Task類型的方法而言,調用結束後,異步操做並未結束,因此即便異步操做發生了異常,也沒法被捕捉。這裏使用一個異步方法進行實驗:性能

由結果可見,攔截先後並無發生異常,異常時在對Task對象等待的時候發生的。單元測試

方案

既然知道了爲什麼沒法攔截,那麼就很容易得出方案:將攔截的範圍延伸到Task方法執行完畢以後的點。首先,拿不帶返回值的Task實驗。對於Task而言,咱們只關心Task結束後的操做,而咱們又不須要爲其返回一個對象。因此,咱們必定可使用一個參數簽名爲Action<Task>的ContinueWith來達到目標,修改Handler的實現以下:測試

public IMethodReturn Invoke(IMethodInvocation input,
            GetNextHandlerDelegate getNext)
        {
            Trace.WriteLine("開始調用");
            var result = getNext()(input, getNext);

            if (result.ReturnValue is Task)
            {
                var task = result.ReturnValue as Task;
                var continued = task.ContinueWith(t =>
                {
                    if (t.IsFaulted)
                    {
                        Trace.WriteLine("發生了異常");
                    }
                    Trace.WriteLine("結束調用");
                });
                return input.CreateMethodReturn(continued);
            }

            if (result.Exception != null)
            {
                Trace.WriteLine("發生了異常");
            }
            Trace.WriteLine("結束調用");
            return result;
        }
View Code

對應的結果以下:spa

對比先後兩次的結果,能夠發現咱們達成了目標。code

困境

對於單純的Task,可使用單純的ContinueWith來解決,然而對於帶返回值的Task<T>,就沒那麼簡單了。這是由於咱們要保證ContinueWith以後,返回的Task的類型和目標方法的返回值類型一直,例如,若是方法要求返回一個Task<int>,那麼,咱們調用的ContinueWith方法必然是要求參數類型爲Func<Task<int>,int>的重載。

事實上,理論上說,咱們能夠把輸入參數限定爲基類Task,從而調用Func<Task,dynamic>這個類型重載,而後經過動態類型適配返回值...然而經過這種方式調用的話,ContinueWith返回的對象類型爲ContinuedXXXTask,一個運行時類型,是沒法和方法簽名上的類型匹配的。

若是放棄這種通用的方式,咱們還能夠對Task<T>的泛型參數T進行判斷,而後轉型,而後再調用ContinueWith,不過這樣的話只能處理已知類型。然而這種最原始的方式給了咱們思路:動態構造。

動態構造

上文中提到的「原始」方法的問題所在是要求枚舉各類類型,咱們知道是不可能的。那麼動態構造的思路就是碰到新類型的時候,咱們就」添加一個If「,即加上一種處理方式。這就很像是在運行期間寫代碼了,而解決這類問題,咱們可使用ExpressionTree。

也就是說,咱們須要在運行時根據不一樣的狀況調用不一樣重載的ContinueWith。這裏分兩步:一,找到合適的方法重載;二,構造合適的參數。

先找方法,對於Task<T>咱們須要調用的參數爲Func<,>的重載,對於Task則是Action<>:

 private MethodInfo FindContinueWith(Type taskType, bool hasReturn)
        {
            var methods = taskType.GetMethods().Where(i => i.Name == "ContinueWith").ToList();
            if (hasReturn)
            {
                var returnType = taskType.GetGenericArguments().First();
                return methods.Where(i =>
                {
                    var pars = i.GetParameters().ToList();
                    return pars.Count == 1 && pars.First().ParameterType.Name.StartsWith("F");
                }).First().MakeGenericMethod(returnType);
            }
            return methods.Where(i =>
            {
                var pars = i.GetParameters().ToList();
                return pars.Count == 1
                       && pars.First().ParameterType.Name.StartsWith("A")
                       && pars.First().ParameterType.IsGenericType;
            })
            .First();
        }
View Code

而後生成參數,這裏先構造一個Expression:

private Expression MakeContinueExpression(Type taskType, Expression actionExp, bool hasReturn)
        {
            ParameterExpression taskParam;
            Expression handelTaskExp;

            if (!hasReturn)
            {
                taskParam = Expression.Parameter(typeof (Task));
                //當Task不帶返回值的時候,使用(t)=>action(t)
                handelTaskExp = Expression.Invoke(actionExp, taskParam);
                return Expression.Lambda(handelTaskExp, taskParam);
            }

            taskParam = Expression.Parameter(taskType);
            handelTaskExp = Expression.Invoke(actionExp, taskParam);
            
            var returnType = taskParam.Type.GetGenericArguments()[0];
            var defaultResult = Expression.Default(returnType);
            var returnTarget = Expression.Label(returnType);
            var returnLable = Expression.Label(returnTarget, defaultResult);
            var paramResult = Expression.PropertyOrField(taskParam, "Result");
            var returnExp = Expression.Return(returnTarget, paramResult);
            //當Task帶返回值的時候,使用(t)=>{action(t);return t.Result;}
            var blockExp = Expression.Block(handelTaskExp, returnExp, returnLable);
            var expression = Expression.Lambda(blockExp, taskParam);
            return expression;
        }
View Code

參數中的Action表明的是咱們須要額外作的事情,這樣作的好處是,對於一個指定的Task<T>,不管你想額外作何時,只須要編譯一次ExpressionTree,這有利於ExpressionTree緩存從而提升性能。

最後就是編譯ExpressionTree生成一個委託:

private Func<Task, Action<Task>, object> MakeContinueTaskFactory(Type taskType, bool hasReturn)
        {
            var key = taskType.FullName;
            return ConcurrentDic.GetOrAdd(key, k =>
            {
                var actionParam = Expression.Parameter(typeof (Action<Task>));
                var continueParam = MakeContinueExpression(taskType, actionParam, hasReturn);
                var taskParam = Expression.Parameter(typeof (Task));
                var taskTexp = Expression.Convert(taskParam, taskType);
                var mehtodInfo = FindContinueWith(taskType, hasReturn);
                var callExp = Expression.Call(taskTexp, mehtodInfo, continueParam);
                var lambda = Expression.Lambda<Func<Task, Action<Task>, object>>(callExp, taskParam, actionParam);
                return lambda.Compile();
            });
        }
View Code

這個方法返回一個委託,該委託接受一個Task,和一個Action,執行後返回另一個Task(ContinueWith)。

這裏是完整的代碼:

using System.Linq.Expressions.Caching;
using System.Reflection;
using System.Threading.Tasks;

// ReSharper disable once CheckNamespace
namespace System.Linq.Expressions
{
    public class TaskInjector : CacheBlock<string, Func<Task, Action<Task>, object>>
    {
        /// <summary>
        /// 獲取Task的ContinueWith方法
        /// </summary>
        /// <param name="taskType"></param>
        /// <param name="hasReturn"></param>
        /// <returns></returns>
        private MethodInfo FindContinueWith(Type taskType, bool hasReturn)
        {
            var methods = taskType.GetMethods().Where(i => i.Name == "ContinueWith").ToList();
            if (hasReturn)
            {
                var returnType = taskType.GetGenericArguments().First();
                return methods.Where(i =>
                {
                    var pars = i.GetParameters().ToList();
                    return pars.Count == 1 && pars.First().ParameterType.Name.StartsWith("F");
                }).First().MakeGenericMethod(returnType);
            }
            return methods.Where(i =>
            {
                var pars = i.GetParameters().ToList();
                return pars.Count == 1
                       && pars.First().ParameterType.Name.StartsWith("A")
                       && pars.First().ParameterType.IsGenericType;
            })
            .First();
        }

        /// <summary>
        /// 針對不一樣Task生成不一樣的ContinueWith委託
        /// </summary>
        /// <param name="taskType"></param>
        /// <param name="actionExp"></param>
        /// <param name="hasReturn"></param>
        /// <returns></returns>
        private Expression MakeContinueExpression(Type taskType, Expression actionExp, bool hasReturn)
        {
            ParameterExpression taskParam;
            Expression handelTaskExp;

            if (!hasReturn)
            {
                taskParam = Expression.Parameter(typeof (Task));
                //當Task不帶返回值的時候,使用(t)=>action(t)
                handelTaskExp = Expression.Invoke(actionExp, taskParam);
                return Expression.Lambda(handelTaskExp, taskParam);
            }

            taskParam = Expression.Parameter(taskType);
            handelTaskExp = Expression.Invoke(actionExp, taskParam);
            
            var returnType = taskParam.Type.GetGenericArguments()[0];
            var defaultResult = Expression.Default(returnType);
            var returnTarget = Expression.Label(returnType);
            var returnLable = Expression.Label(returnTarget, defaultResult);
            var paramResult = Expression.PropertyOrField(taskParam, "Result");
            var returnExp = Expression.Return(returnTarget, paramResult);
            //當Task帶返回值的時候,使用(t)=>{action(t);return t.Result;}
            var blockExp = Expression.Block(handelTaskExp, returnExp, returnLable);
            var expression = Expression.Lambda(blockExp, taskParam);
            return expression;
        }

        /// <summary>
        /// 爲指定的Task類型編譯一個ContinueWith的生成器
        /// </summary>
        /// <param name="taskType"></param>
        /// <param name="hasReturn"></param>
        /// <returns></returns>
        private Func<Task, Action<Task>, object> MakeContinueTaskFactory(Type taskType, bool hasReturn)
        {
            var key = taskType.FullName;
            return ConcurrentDic.GetOrAdd(key, k =>
            {
                var actionParam = Expression.Parameter(typeof (Action<Task>));
                var continueParam = MakeContinueExpression(taskType, actionParam, hasReturn);
                var taskParam = Expression.Parameter(typeof (Task));
                var taskTexp = Expression.Convert(taskParam, taskType);
                var mehtodInfo = FindContinueWith(taskType, hasReturn);
                var callExp = Expression.Call(taskTexp, mehtodInfo, continueParam);
                var lambda = Expression.Lambda<Func<Task, Action<Task>, object>>(callExp, taskParam, actionParam);
                return lambda.Compile();
            });
        }

        /// <summary>
        /// 爲Task類型的對象注入代碼
        /// </summary>
        /// <param name="task"></param>
        /// <param name="action"></param>
        /// <returns></returns>
        public object Inject(Task task, Action<Task> action)
        {
            var runtimeType = task.GetType();
            var hasReturn = runtimeType.IsGenericType && runtimeType.GetProperty("Result").PropertyType.Name != "VoidTaskResult";
            var func = MakeContinueTaskFactory(runtimeType, hasReturn);
            return func(task, action);
        }

        public static TaskInjector Instance = new TaskInjector();
    }
}
View Code

以及一個輔助的緩存類:

using System.Linq.Expressions.Caching;
using System.Reflection;
using System.Threading.Tasks;

// ReSharper disable once CheckNamespace
namespace System.Linq.Expressions
{
    public class TaskInjector : CacheBlock<string, Func<Task, Action<Task>, object>>
    {
        /// <summary>
        /// 獲取Task的ContinueWith方法
        /// </summary>
        /// <param name="taskType"></param>
        /// <param name="hasReturn"></param>
        /// <returns></returns>
        private MethodInfo FindContinueWith(Type taskType, bool hasReturn)
        {
            var methods = taskType.GetMethods().Where(i => i.Name == "ContinueWith").ToList();
            if (hasReturn)
            {
                var returnType = taskType.GetGenericArguments().First();
                return methods.Where(i =>
                {
                    var pars = i.GetParameters().ToList();
                    return pars.Count == 1 && pars.First().ParameterType.Name.StartsWith("F");
                }).First().MakeGenericMethod(returnType);
            }
            return methods.Where(i =>
            {
                var pars = i.GetParameters().ToList();
                return pars.Count == 1
                       && pars.First().ParameterType.Name.StartsWith("A")
                       && pars.First().ParameterType.IsGenericType;
            })
            .First();
        }

        /// <summary>
        /// 針對不一樣Task生成不一樣的ContinueWith委託
        /// </summary>
        /// <param name="taskType"></param>
        /// <param name="actionExp"></param>
        /// <param name="hasReturn"></param>
        /// <returns></returns>
        private Expression MakeContinueExpression(Type taskType, Expression actionExp, bool hasReturn)
        {
            ParameterExpression taskParam;
            Expression handelTaskExp;

            if (!hasReturn)
            {
                taskParam = Expression.Parameter(typeof (Task));
                //當Task不帶返回值的時候,使用(t)=>action(t)
                handelTaskExp = Expression.Invoke(actionExp, taskParam);
                return Expression.Lambda(handelTaskExp, taskParam);
            }

            taskParam = Expression.Parameter(taskType);
            handelTaskExp = Expression.Invoke(actionExp, taskParam);
            
            var returnType = taskParam.Type.GetGenericArguments()[0];
            var defaultResult = Expression.Default(returnType);
            var returnTarget = Expression.Label(returnType);
            var returnLable = Expression.Label(returnTarget, defaultResult);
            var paramResult = Expression.PropertyOrField(taskParam, "Result");
            var returnExp = Expression.Return(returnTarget, paramResult);
            //當Task帶返回值的時候,使用(t)=>{action(t);return t.Result;}
            var blockExp = Expression.Block(handelTaskExp, returnExp, returnLable);
            var expression = Expression.Lambda(blockExp, taskParam);
            return expression;
        }

        /// <summary>
        /// 爲指定的Task類型編譯一個ContinueWith的生成器
        /// </summary>
        /// <param name="taskType"></param>
        /// <param name="hasReturn"></param>
        /// <returns></returns>
        private Func<Task, Action<Task>, object> MakeContinueTaskFactory(Type taskType, bool hasReturn)
        {
            var key = taskType.FullName;
            return ConcurrentDic.GetOrAdd(key, k =>
            {
                var actionParam = Expression.Parameter(typeof (Action<Task>));
                var continueParam = MakeContinueExpression(taskType, actionParam, hasReturn);
                var taskParam = Expression.Parameter(typeof (Task));
                var taskTexp = Expression.Convert(taskParam, taskType);
                var mehtodInfo = FindContinueWith(taskType, hasReturn);
                var callExp = Expression.Call(taskTexp, mehtodInfo, continueParam);
                var lambda = Expression.Lambda<Func<Task, Action<Task>, object>>(callExp, taskParam, actionParam);
                return lambda.Compile();
            });
        }

        /// <summary>
        /// 爲Task類型的對象注入代碼
        /// </summary>
        /// <param name="task"></param>
        /// <param name="action"></param>
        /// <returns></returns>
        public object Inject(Task task, Action<Task> action)
        {
            var runtimeType = task.GetType();
            var hasReturn = runtimeType.IsGenericType && runtimeType.GetProperty("Result").PropertyType.Name != "VoidTaskResult";
            var func = MakeContinueTaskFactory(runtimeType, hasReturn);
            return func(task, action);
        }

        public static TaskInjector Instance = new TaskInjector();
    }
}
View Code

此時,咱們就能夠繼續改造Handler實現: 

 public IMethodReturn Invoke(IMethodInvocation input,
            GetNextHandlerDelegate getNext)
        {
            Trace.WriteLine("開始調用");
            var result = getNext()(input, getNext);

            if (result.ReturnValue is Task)
            {
                var task = result.ReturnValue as Task;
                var continued = TaskInjector.Instance.Inject(task, (t) =>
                {
                    if (t.IsFaulted)
                    {
                        Trace.WriteLine("發生了異常");
                    }
                    Trace.WriteLine("偷看值:" + PropertyFieldLoader.Instance.Load<object>(task, task.GetType(), "Result"));
                    Trace.WriteLine("結束調用");
                });
                return input.CreateMethodReturn(continued);
            }

            if (result.Exception != null)
            {
                Trace.WriteLine("發生了異常");
            }
            Trace.WriteLine("結束調用");
            return result;
        }
View Code

對應的結果以下:

而對於一個返回Task<T>的方法:

public virtual async Task<int> GetIntAsync()
        {
            return await Task.FromResult(1000);
        }

結果以下:

相關文章
相關標籤/搜索