blob: fd15f9de4f0f9df211ddd821ba857d0efbe34eed (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
|
#if NET_4_5
using System;
using System.Collections.Generic;
using System.Reflection;
using System.Runtime.ExceptionServices;
using System.Threading;
namespace NUnit.Framework.Internal
{
/// <summary>
/// Brackets the invocation of an async method (one carrying the compiler-emitted
/// <c>AsyncStateMachineAttribute</c>) so the caller can block until it has finished
/// and obtain its unwrapped result. Dispose the region once the invocation is done.
/// </summary>
internal abstract class AsyncInvocationRegion : IDisposable
{
    // Looked up by name; null when the running framework does not provide the
    // attribute, in which case no method is treated as async.
    private static readonly Type AsyncStateMachineAttributeType =
        Type.GetType("System.Runtime.CompilerServices.AsyncStateMachineAttribute");

    // Private so that only the nested region types below can derive from this class.
    private AsyncInvocationRegion()
    {
    }

    /// <summary>
    /// Creates a region suited to the method targeted by <paramref name="delegate"/>.
    /// </summary>
    /// <param name="delegate">Delegate whose target method is async.</param>
    /// <returns>A new region; the caller is responsible for disposing it.</returns>
    /// <exception cref="InvalidOperationException">The target method is not async.</exception>
    public static AsyncInvocationRegion Create(Delegate @delegate)
    {
        return Create(@delegate.Method);
    }

    /// <summary>
    /// Creates a region suited to <paramref name="method"/>: an async-void region when the
    /// method returns <c>void</c>, otherwise a task-waiting region.
    /// </summary>
    /// <param name="method">The async method about to be invoked.</param>
    /// <returns>A new region; the caller is responsible for disposing it.</returns>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="method"/> is not async, or async support is unavailable.
    /// </exception>
    public static AsyncInvocationRegion Create(MethodInfo method)
    {
        if (!IsAsyncOperation(method))
        {
            throw new InvalidOperationException(@"Either asynchronous support is not available or an attempt
at wrapping a non-async method invocation in an async region was done");
        }

        // An async void method hands back no task to wait on, so its completion has to be
        // observed through a synchronization context instead.
        if (method.ReturnType == typeof(void))
        {
            return new AsyncVoidInvocationRegion();
        }

        return new AsyncTaskInvocationRegion();
    }

    /// <summary>
    /// Returns true when the async-state-machine attribute type could be resolved and is
    /// declared directly on <paramref name="method"/> (inherited attributes are ignored).
    /// </summary>
    public static bool IsAsyncOperation(MethodInfo method)
    {
        return AsyncStateMachineAttributeType != null
            && method.IsDefined(AsyncStateMachineAttributeType, false);
    }

    /// <summary>
    /// Returns true when the method targeted by <paramref name="delegate"/> is async.
    /// </summary>
    public static bool IsAsyncOperation(Delegate @delegate)
    {
        return IsAsyncOperation(@delegate.Method);
    }

    /// <summary>
    /// Waits for pending asynchronous operations to complete, if appropriate,
    /// and returns a proper result of the invocation by unwrapping task results.
    /// </summary>
    /// <param name="invocationResult">The raw result of the method invocation.</param>
    /// <returns>The unwrapped result, if necessary.</returns>
    public abstract object WaitForPendingOperationsToComplete(object invocationResult);

    /// <summary>Ends the region. The base implementation does nothing.</summary>
    public virtual void Dispose()
    {
    }

    /// <summary>
    /// Region for <c>async void</c> methods. Construction installs a fresh
    /// <see cref="AsyncSynchronizationContext"/> as the current context.
    /// <see cref="Dispose"/> restores the context that was current before.
    /// </summary>
    private sealed class AsyncVoidInvocationRegion : AsyncInvocationRegion
    {
        private readonly SynchronizationContext _previousContext;
        private readonly AsyncSynchronizationContext _currentContext;

        public AsyncVoidInvocationRegion()
        {
            _previousContext = SynchronizationContext.Current;
            _currentContext = new AsyncSynchronizationContext();
            SynchronizationContext.SetSynchronizationContext(_currentContext);
        }

        public override void Dispose()
        {
            SynchronizationContext.SetSynchronizationContext(_previousContext);
        }

        /// <summary>
        /// Delegates the wait to the installed context and returns
        /// <paramref name="invocationResult"/> unchanged.
        /// </summary>
        public override object WaitForPendingOperationsToComplete(object invocationResult)
        {
            // NOTE(review): this relies on AsyncSynchronizationContext blocking until the
            // operations posted to it have finished; that behavior is defined elsewhere.
            _currentContext.WaitForPendingOperationsToComplete();
            return invocationResult;
        }
    }

    /// <summary>
    /// Region for async methods that return a task. It blocks on the task and then
    /// unwraps either its result or its failure.
    /// </summary>
    private sealed class AsyncTaskInvocationRegion : AsyncInvocationRegion
    {
        private const string TaskWaitMethod = "Wait";
        private const string TaskResultProperty = "Result";

        // On the .NET Framework, an `async Task` method returns a Task<VoidTaskResult> at
        // runtime. That internal placeholder is not a value the method produced.
        private const string VoidTaskResultTypeName = "VoidTaskResult";

        private const BindingFlags TaskResultPropertyBindingFlags = BindingFlags.GetProperty | BindingFlags.Instance | BindingFlags.Public;

        /// <summary>
        /// Calls the task's parameterless <c>Wait()</c>. If that fails, the first underlying
        /// exception is rethrown with its original stack trace. Otherwise the method returns:
        /// <c>null</c> for a Task&lt;VoidTaskResult&gt;; the value of the public
        /// <c>Result</c> property when one exists; or <paramref name="invocationResult"/> itself.
        /// </summary>
        public override object WaitForPendingOperationsToComplete(object invocationResult)
        {
            Type taskType = invocationResult.GetType();

            try
            {
                taskType.GetMethod(TaskWaitMethod, Type.EmptyTypes).Invoke(invocationResult, null);
            }
            catch (TargetInvocationException e)
            {
                // Invoke wraps whatever Wait() threw (for a faulted task, an AggregateException).
                IList<Exception> innerExceptions = GetAllExceptions(e.InnerException ?? e);

                // Rethrow the first real failure without resetting its stack trace to this
                // frame, which a plain `throw ex` would do.
                ExceptionDispatchInfo.Capture(innerExceptions[0]).Throw();
                throw; // Unreachable: Throw() never returns, but the compiler cannot prove it.
            }

            if (IsVoidTaskResult(taskType))
            {
                return null;
            }

            PropertyInfo taskResultProperty = taskType.GetProperty(TaskResultProperty, TaskResultPropertyBindingFlags);
            return taskResultProperty != null ? taskResultProperty.GetValue(invocationResult, null) : invocationResult;
        }

        // True when taskType is a single-argument generic type whose argument is named
        // VoidTaskResult, i.e. the runtime task type behind an `async Task` method.
        private static bool IsVoidTaskResult(Type taskType)
        {
            if (!taskType.IsGenericType)
            {
                return false;
            }

            Type[] typeArguments = taskType.GetGenericArguments();
            return typeArguments.Length == 1
                && string.Equals(typeArguments[0].Name, VoidTaskResultTypeName, StringComparison.Ordinal);
        }

        // Unwraps an AggregateException into its inner exceptions. Any other exception,
        // or an aggregate with no inner exceptions, becomes a one-element list, so
        // callers can always index [0].
        private static IList<Exception> GetAllExceptions(Exception exception)
        {
            AggregateException aggregate = exception as AggregateException;
            if (aggregate != null && aggregate.InnerExceptions.Count > 0)
            {
                return aggregate.InnerExceptions;
            }

            return new[] { exception };
        }
    }
}
}
#endif
|