diff options
author | Morgan Brown <morganbr@users.noreply.github.com> | 2018-09-07 07:53:46 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-09-07 07:53:46 +0300 |
commit | 296c7356f3df796b62d965bc0d9a315390ef4fcf (patch) | |
tree | 81bcdc0866f5191c92dc478e7783f6410f63698d | |
parent | 9ca75bb786f8abcfc022a7d2f14b4fc2d8e36301 (diff) |
Fix constrained calls and marshaling (#6275)
Codegen fixes required to make Console.WriteLine work on WebAssembly (when combined with #5987 and a matching CoreFX build):
The 'this' pointer for reference types is a byref for constrained virtual calls and needs to be dereferenced.
Enable P/Invoke marshaling. This isn't easy to test directly, but we'll notice the effect as we try to use framework code that P/Invokes (in Console.WriteLine, there's SafeHandle marshaling).
-rw-r--r-- | src/BuildIntegration/Microsoft.NETCore.Native.targets | 2 | ||||
-rw-r--r-- | src/ILCompiler.WebAssembly/src/CodeGen/ILToWebAssemblyImporter.cs | 76 | ||||
-rw-r--r-- | tests/src/Simple/HelloWasm/Program.cs | 205 |
3 files changed, 260 insertions, 23 deletions
diff --git a/src/BuildIntegration/Microsoft.NETCore.Native.targets b/src/BuildIntegration/Microsoft.NETCore.Native.targets index 79f3a61b7..66af195d0 100644 --- a/src/BuildIntegration/Microsoft.NETCore.Native.targets +++ b/src/BuildIntegration/Microsoft.NETCore.Native.targets @@ -250,7 +250,7 @@ See the LICENSE file in the project root for more information. <PropertyGroup> <EmccArgs>"$(NativeObject)" -o "$(NativeBinary)" -s WASM=1 -s ALLOW_MEMORY_GROWTH=1 --emrun </EmccArgs> - <EmccArgs Condition="'$(Platform)'=='wasm'">$(EmccArgs) "$(IlcPath)/sdk/libPortableRuntime.bc" "$(IlcPath)/sdk/libbootstrappercpp.bc" "$(IlcPath)/sdk/libSystem.Private.CoreLib.Native.bc" </EmccArgs> + <EmccArgs Condition="'$(Platform)'=='wasm'">$(EmccArgs) "$(IlcPath)/sdk/libPortableRuntime.bc" "$(IlcPath)/sdk/libbootstrappercpp.bc" "$(IlcPath)/sdk/libSystem.Private.CoreLib.Native.bc" </EmccArgs> <EmccArgs Condition="'$(Configuration)'=='Release'">$(EmccArgs) -O2 --llvm-lto 2</EmccArgs> <EmccArgs Condition="'$(Configuration)'=='Debug'">$(EmccArgs) -g3</EmccArgs> </PropertyGroup> diff --git a/src/ILCompiler.WebAssembly/src/CodeGen/ILToWebAssemblyImporter.cs b/src/ILCompiler.WebAssembly/src/CodeGen/ILToWebAssemblyImporter.cs index 2701adb4b..e828d6d97 100644 --- a/src/ILCompiler.WebAssembly/src/CodeGen/ILToWebAssemblyImporter.cs +++ b/src/ILCompiler.WebAssembly/src/CodeGen/ILToWebAssemblyImporter.cs @@ -54,6 +54,7 @@ namespace Internal.IL private readonly byte[] _ilBytes; private MethodDebugInformation _debugInformation; private LLVMMetadataRef _debugFunction; + private TypeDesc _constrainedType = null; /// <summary> /// Stack of values pushed onto the IL stack: locals, arguments, values, function pointer, ... @@ -503,6 +504,10 @@ namespace Internal.IL private void EndImportingInstruction() { + // If this was constrained used in a call, it's already been cleared, + // but if it was on some other instruction, it shoudln't carry forward + _constrainedType = null; + // Reset the debug position so it doesn't end up applying to the wrong instructions LLVM.SetCurrentDebugLocation(_builder, default(LLVMValueRef)); } @@ -1198,7 +1203,7 @@ namespace Internal.IL } } - if (callee.IsPInvoke || (callee.IsInternalCall && callee.HasCustomAttribute("System.Runtime", "RuntimeImportAttribute"))) + if (callee.IsRawPInvoke() || (callee.IsInternalCall && callee.HasCustomAttribute("System.Runtime", "RuntimeImportAttribute"))) { ImportRawPInvoke(callee); return; @@ -1267,12 +1272,6 @@ namespace Internal.IL } } - // we don't really have virtual call support, but we'll treat it as direct for now - if (opcode != ILOpcode.call && opcode != ILOpcode.callvirt && opcode != ILOpcode.newobj) - { - throw new NotImplementedException(); - } - if (opcode == ILOpcode.newobj && callee.OwningType.IsDelegate) { FunctionPointerEntry functionPointer = ((FunctionPointerEntry)_stack.Peek()); @@ -1284,10 +1283,12 @@ namespace Internal.IL } } - HandleCall(callee, callee.Signature, opcode); + TypeDesc localConstrainedType = _constrainedType; + _constrainedType = null; + HandleCall(callee, callee.Signature, opcode, localConstrainedType); } - private LLVMValueRef LLVMFunctionForMethod(MethodDesc callee, StackEntry thisPointer, bool isCallVirt) + private LLVMValueRef LLVMFunctionForMethod(MethodDesc callee, StackEntry thisPointer, bool isCallVirt, TypeDesc constrainedType) { string calleeName = _compilation.NameMangler.GetMangledMethodName(callee).ToString(); @@ -1321,17 +1322,23 @@ namespace Internal.IL isValueTypeCall = true; } } - if (callee.OwningType.IsInterface) + + if(constrainedType != null && constrainedType.IsValueType) { - // For value types, devirtualize the call - if (isValueTypeCall) + isValueTypeCall = true; + } + + if (isValueTypeCall) + { + if (constrainedType != null) + { + targetMethod = constrainedType.TryResolveConstraintMethodApprox(callee.OwningType, callee, out _); + } + else if (callee.OwningType.IsInterface) { targetMethod = parameterType.ResolveInterfaceMethodTarget(callee); } - } - else - { - if (isValueTypeCall) + else { targetMethod = parameterType.FindVirtualFunctionTargetMethodOnObjectType(callee); } @@ -1584,7 +1591,7 @@ namespace Internal.IL return false; } - private void HandleCall(MethodDesc callee, MethodSignature signature, ILOpcode opcode = ILOpcode.call, LLVMValueRef calliTarget = default(LLVMValueRef)) + private void HandleCall(MethodDesc callee, MethodSignature signature, ILOpcode opcode = ILOpcode.call, TypeDesc constrainedType = null, LLVMValueRef calliTarget = default(LLVMValueRef)) { var parameterCount = signature.Length + (signature.IsStatic ? 0 : 1); // The last argument is the top of the stack. We need to reverse them and store starting at the first argument @@ -1593,10 +1600,34 @@ namespace Internal.IL { argumentValues[argumentValues.Length - i - 1] = _stack.Pop(); } - PushNonNull(HandleCall(callee, signature, argumentValues, opcode, calliTarget)); + + if (constrainedType != null) + { + if (signature.IsStatic) + { + // Constrained call on static method + ThrowHelper.ThrowInvalidProgramException(ExceptionStringID.InvalidProgramSpecific, _method); + } + StackEntry thisByRef = argumentValues[0]; + if (thisByRef.Kind != StackValueKind.ByRef) + { + // Constrained call without byref + ThrowHelper.ThrowInvalidProgramException(ExceptionStringID.InvalidProgramSpecific, _method); + } + + // If this is a constrained call and the 'this' pointer is a reference type, it's a byref, + // dereference it before calling. + if (!constrainedType.IsValueType) + { + TypeDesc objectType = thisByRef.Type.GetParameterType(); + argumentValues[0] = new LoadExpressionEntry(StackValueKind.ObjRef, "thisPtr", thisByRef.ValueAsType(objectType, _builder), objectType); + } + } + + PushNonNull(HandleCall(callee, signature, argumentValues, opcode, constrainedType, calliTarget)); } - private ExpressionEntry HandleCall(MethodDesc callee, MethodSignature signature, StackEntry[] argumentValues, ILOpcode opcode = ILOpcode.call, LLVMValueRef calliTarget = default(LLVMValueRef), TypeDesc forcedReturnType = null) + private ExpressionEntry HandleCall(MethodDesc callee, MethodSignature signature, StackEntry[] argumentValues, ILOpcode opcode = ILOpcode.call, TypeDesc constrainedType = null, LLVMValueRef calliTarget = default(LLVMValueRef), TypeDesc forcedReturnType = null) { if (opcode == ILOpcode.callvirt && callee.IsVirtual) { @@ -1687,7 +1718,7 @@ namespace Internal.IL } else { - fn = LLVMFunctionForMethod(callee, signature.IsStatic ? null : argumentValues[0], opcode == ILOpcode.callvirt); + fn = LLVMFunctionForMethod(callee, signature.IsStatic ? null : argumentValues[0], opcode == ILOpcode.callvirt, constrainedType); } LLVMValueRef llvmReturn = LLVM.BuildCall(_builder, fn, llvmArgs.ToArray(), string.Empty); @@ -1955,7 +1986,7 @@ namespace Internal.IL private void ImportCalli(int token) { MethodSignature methodSignature = (MethodSignature)_methodIL.GetObject(token); - HandleCall(null, methodSignature, ILOpcode.calli, ((ExpressionEntry)_stack.Pop()).ValueAsType(LLVM.PointerType(GetLLVMSignatureForMethod(methodSignature), 0), _builder)); + HandleCall(null, methodSignature, ILOpcode.calli, calliTarget: ((ExpressionEntry)_stack.Pop()).ValueAsType(LLVM.PointerType(GetLLVMSignatureForMethod(methodSignature), 0), _builder)); } private void ImportLdFtn(int token, ILOpcode opCode) @@ -1967,7 +1998,7 @@ namespace Internal.IL StackEntry thisPointer = _stack.Pop(); if (method.IsVirtual) { - targetLLVMFunction = LLVMFunctionForMethod(method, thisPointer, true); + targetLLVMFunction = LLVMFunctionForMethod(method, thisPointer, true, null); AddVirtualMethodReference(method); } } @@ -2652,6 +2683,7 @@ namespace Internal.IL private void ImportConstrainedPrefix(int token) { + _constrainedType = (TypeDesc)_methodIL.GetObject(token); } private void ImportNoPrefix(byte mask) diff --git a/tests/src/Simple/HelloWasm/Program.cs b/tests/src/Simple/HelloWasm/Program.cs index 95563912a..98ccb4eb1 100644 --- a/tests/src/Simple/HelloWasm/Program.cs +++ b/tests/src/Simple/HelloWasm/Program.cs @@ -15,6 +15,8 @@ internal static class Program private static int threadStaticInt; private static unsafe int Main(string[] args) { + PrintLine("Starting"); + Add(1, 2); int tempInt = 0; int tempInt2 = 0; @@ -310,6 +312,8 @@ internal static class Program PrintLine("ByReference intrinsics exercise via ReadOnlySpan OK."); } + TestConstrainedClassCalls(); + // This test should remain last to get other results before stopping the debugger PrintLine("Debugger.Break() test: Ok if debugger is open and breaks."); System.Diagnostics.Debugger.Break(); @@ -490,6 +494,113 @@ internal static class Program } } + private static void TestConstrainedClassCalls() + { + string s = "utf-8"; + + PrintString("Direct ToString test: "); + string stringDirectToString = s.ToString(); + if (s.Equals(stringDirectToString)) + { + PrintLine("Ok."); + } + else + { + PrintString("Failed. Returned string:\""); + PrintString(stringDirectToString); + PrintLine("\""); + } + + // Generic calls on methods not defined on object + uint dataFromBase = GenericGetData<MyBase>(new MyBase(11)); + PrintString("Generic call to base class test: "); + if (dataFromBase == 11) + { + PrintLine("Ok."); + } + else + { + PrintLine("Failed."); + } + + uint dataFromUnsealed = GenericGetData<UnsealedDerived>(new UnsealedDerived(13)); + PrintString("Generic call to unsealed derived class test: "); + if (dataFromUnsealed == 26) + { + PrintLine("Ok."); + } + else + { + PrintLine("Failed."); + } + + uint dataFromSealed = GenericGetData<SealedDerived>(new SealedDerived(15)); + PrintString("Generic call to sealed derived class test: "); + if (dataFromSealed == 45) + { + PrintLine("Ok."); + } + else + { + PrintLine("Failed."); + } + + uint dataFromUnsealedAsBase = GenericGetData<MyBase>(new UnsealedDerived(17)); + PrintString("Generic call to unsealed derived class as base test: "); + if (dataFromUnsealedAsBase == 34) + { + PrintLine("Ok."); + } + else + { + PrintLine("Failed."); + } + + uint dataFromSealedAsBase = GenericGetData<MyBase>(new SealedDerived(19)); + PrintString("Generic call to sealed derived class as base test: "); + if (dataFromSealedAsBase == 57) + { + PrintLine("Ok."); + } + else + { + PrintLine("Failed."); + } + + // Generic calls to methods defined on object + uint hashCodeOfSealedViaGeneric = (uint)GenericGetHashCode<MySealedClass>(new MySealedClass(37)); + PrintString("Generic GetHashCode for sealed class test: "); + if (hashCodeOfSealedViaGeneric == 74) + { + PrintLine("Ok."); + } + else + { + PrintLine("Failed."); + } + + uint hashCodeOfUnsealedViaGeneric = (uint)GenericGetHashCode<MyUnsealedClass>(new MyUnsealedClass(41)); + PrintString("Generic GetHashCode for unsealed class test: "); + if (hashCodeOfUnsealedViaGeneric == 82) + { + PrintLine("Ok."); + } + else + { + PrintLine("Failed."); + } + } + + static uint GenericGetData<T>(T obj) where T : MyBase + { + return obj.GetData(); + } + + static int GenericGetHashCode<T>(T obj) + { + return obj.GetHashCode(); + } + [DllImport("*")] private static unsafe extern int printf(byte* str, byte* unused); } @@ -646,3 +757,97 @@ public struct ItfStruct : ITestItf return 4; } } + +public sealed class MySealedClass +{ + uint _data; + + public MySealedClass() + { + _data = 104; + } + + public MySealedClass(uint data) + { + _data = data; + } + + public uint GetData() + { + return _data; + } + + public override int GetHashCode() + { + return (int)_data * 2; + } + + public override string ToString() + { + Program.PrintLine("MySealedClass.ToString called. Data:"); + Program.PrintLine(_data.ToString()); + return _data.ToString(); + } +} + +public class MyUnsealedClass +{ + uint _data; + + public MyUnsealedClass() + { + _data = 24; + } + + public MyUnsealedClass(uint data) + { + _data = data; + } + + public uint GetData() + { + return _data; + } + + public override int GetHashCode() + { + return (int)_data * 2; + } + + public override string ToString() + { + return _data.ToString(); + } +} + +public class MyBase +{ + protected uint _data; + public MyBase(uint data) + { + _data = data; + } + + public virtual uint GetData() + { + return _data; + } +} + +public class UnsealedDerived : MyBase +{ + public UnsealedDerived(uint data) : base(data) { } + public override uint GetData() + { + return _data * 2; + } +} + +public sealed class SealedDerived : MyBase +{ + public SealedDerived(uint data) : base(data) { } + public override uint GetData() + { + return _data * 3; + } +} |