From 296c7356f3df796b62d965bc0d9a315390ef4fcf Mon Sep 17 00:00:00 2001 From: Morgan Brown Date: Thu, 6 Sep 2018 21:53:46 -0700 Subject: Fix constrained calls and marshaling (#6275) Codegen fixes required to make Console.WriteLine work on WebAssembly (when combined with #5987 and a matching CoreFX build): The 'this' pointer for reference types is a byref for constrained virtual calls and needs to be dereferenced. Enable P/Invoke marshaling. This isn't easy to test directly, but we'll notice the effect as we try to use framework code that P/Invokes (in Console.WriteLine, there's SafeHandle marshaling). --- .../Microsoft.NETCore.Native.targets | 2 +- .../src/CodeGen/ILToWebAssemblyImporter.cs | 76 +++++--- tests/src/Simple/HelloWasm/Program.cs | 205 +++++++++++++++++++++ 3 files changed, 260 insertions(+), 23 deletions(-) diff --git a/src/BuildIntegration/Microsoft.NETCore.Native.targets b/src/BuildIntegration/Microsoft.NETCore.Native.targets index 79f3a61b7..66af195d0 100644 --- a/src/BuildIntegration/Microsoft.NETCore.Native.targets +++ b/src/BuildIntegration/Microsoft.NETCore.Native.targets @@ -250,7 +250,7 @@ See the LICENSE file in the project root for more information. "$(NativeObject)" -o "$(NativeBinary)" -s WASM=1 -s ALLOW_MEMORY_GROWTH=1 --emrun - $(EmccArgs) "$(IlcPath)/sdk/libPortableRuntime.bc" "$(IlcPath)/sdk/libbootstrappercpp.bc" "$(IlcPath)/sdk/libSystem.Private.CoreLib.Native.bc" + $(EmccArgs) "$(IlcPath)/sdk/libPortableRuntime.bc" "$(IlcPath)/sdk/libbootstrappercpp.bc" "$(IlcPath)/sdk/libSystem.Private.CoreLib.Native.bc" $(EmccArgs) -O2 --llvm-lto 2 $(EmccArgs) -g3 diff --git a/src/ILCompiler.WebAssembly/src/CodeGen/ILToWebAssemblyImporter.cs b/src/ILCompiler.WebAssembly/src/CodeGen/ILToWebAssemblyImporter.cs index 2701adb4b..e828d6d97 100644 --- a/src/ILCompiler.WebAssembly/src/CodeGen/ILToWebAssemblyImporter.cs +++ b/src/ILCompiler.WebAssembly/src/CodeGen/ILToWebAssemblyImporter.cs @@ -54,6 +54,7 @@ namespace Internal.IL private readonly byte[] _ilBytes; private MethodDebugInformation _debugInformation; private LLVMMetadataRef _debugFunction; + private TypeDesc _constrainedType = null; /// /// Stack of values pushed onto the IL stack: locals, arguments, values, function pointer, ... @@ -503,6 +504,10 @@ namespace Internal.IL private void EndImportingInstruction() { + // If this was constrained used in a call, it's already been cleared, + // but if it was on some other instruction, it shoudln't carry forward + _constrainedType = null; + // Reset the debug position so it doesn't end up applying to the wrong instructions LLVM.SetCurrentDebugLocation(_builder, default(LLVMValueRef)); } @@ -1198,7 +1203,7 @@ namespace Internal.IL } } - if (callee.IsPInvoke || (callee.IsInternalCall && callee.HasCustomAttribute("System.Runtime", "RuntimeImportAttribute"))) + if (callee.IsRawPInvoke() || (callee.IsInternalCall && callee.HasCustomAttribute("System.Runtime", "RuntimeImportAttribute"))) { ImportRawPInvoke(callee); return; @@ -1267,12 +1272,6 @@ namespace Internal.IL } } - // we don't really have virtual call support, but we'll treat it as direct for now - if (opcode != ILOpcode.call && opcode != ILOpcode.callvirt && opcode != ILOpcode.newobj) - { - throw new NotImplementedException(); - } - if (opcode == ILOpcode.newobj && callee.OwningType.IsDelegate) { FunctionPointerEntry functionPointer = ((FunctionPointerEntry)_stack.Peek()); @@ -1284,10 +1283,12 @@ namespace Internal.IL } } - HandleCall(callee, callee.Signature, opcode); + TypeDesc localConstrainedType = _constrainedType; + _constrainedType = null; + HandleCall(callee, callee.Signature, opcode, localConstrainedType); } - private LLVMValueRef LLVMFunctionForMethod(MethodDesc callee, StackEntry thisPointer, bool isCallVirt) + private LLVMValueRef LLVMFunctionForMethod(MethodDesc callee, StackEntry thisPointer, bool isCallVirt, TypeDesc constrainedType) { string calleeName = _compilation.NameMangler.GetMangledMethodName(callee).ToString(); @@ -1321,17 +1322,23 @@ namespace Internal.IL isValueTypeCall = true; } } - if (callee.OwningType.IsInterface) + + if(constrainedType != null && constrainedType.IsValueType) { - // For value types, devirtualize the call - if (isValueTypeCall) + isValueTypeCall = true; + } + + if (isValueTypeCall) + { + if (constrainedType != null) + { + targetMethod = constrainedType.TryResolveConstraintMethodApprox(callee.OwningType, callee, out _); + } + else if (callee.OwningType.IsInterface) { targetMethod = parameterType.ResolveInterfaceMethodTarget(callee); } - } - else - { - if (isValueTypeCall) + else { targetMethod = parameterType.FindVirtualFunctionTargetMethodOnObjectType(callee); } @@ -1584,7 +1591,7 @@ namespace Internal.IL return false; } - private void HandleCall(MethodDesc callee, MethodSignature signature, ILOpcode opcode = ILOpcode.call, LLVMValueRef calliTarget = default(LLVMValueRef)) + private void HandleCall(MethodDesc callee, MethodSignature signature, ILOpcode opcode = ILOpcode.call, TypeDesc constrainedType = null, LLVMValueRef calliTarget = default(LLVMValueRef)) { var parameterCount = signature.Length + (signature.IsStatic ? 0 : 1); // The last argument is the top of the stack. We need to reverse them and store starting at the first argument @@ -1593,10 +1600,34 @@ namespace Internal.IL { argumentValues[argumentValues.Length - i - 1] = _stack.Pop(); } - PushNonNull(HandleCall(callee, signature, argumentValues, opcode, calliTarget)); + + if (constrainedType != null) + { + if (signature.IsStatic) + { + // Constrained call on static method + ThrowHelper.ThrowInvalidProgramException(ExceptionStringID.InvalidProgramSpecific, _method); + } + StackEntry thisByRef = argumentValues[0]; + if (thisByRef.Kind != StackValueKind.ByRef) + { + // Constrained call without byref + ThrowHelper.ThrowInvalidProgramException(ExceptionStringID.InvalidProgramSpecific, _method); + } + + // If this is a constrained call and the 'this' pointer is a reference type, it's a byref, + // dereference it before calling. + if (!constrainedType.IsValueType) + { + TypeDesc objectType = thisByRef.Type.GetParameterType(); + argumentValues[0] = new LoadExpressionEntry(StackValueKind.ObjRef, "thisPtr", thisByRef.ValueAsType(objectType, _builder), objectType); + } + } + + PushNonNull(HandleCall(callee, signature, argumentValues, opcode, constrainedType, calliTarget)); } - private ExpressionEntry HandleCall(MethodDesc callee, MethodSignature signature, StackEntry[] argumentValues, ILOpcode opcode = ILOpcode.call, LLVMValueRef calliTarget = default(LLVMValueRef), TypeDesc forcedReturnType = null) + private ExpressionEntry HandleCall(MethodDesc callee, MethodSignature signature, StackEntry[] argumentValues, ILOpcode opcode = ILOpcode.call, TypeDesc constrainedType = null, LLVMValueRef calliTarget = default(LLVMValueRef), TypeDesc forcedReturnType = null) { if (opcode == ILOpcode.callvirt && callee.IsVirtual) { @@ -1687,7 +1718,7 @@ namespace Internal.IL } else { - fn = LLVMFunctionForMethod(callee, signature.IsStatic ? null : argumentValues[0], opcode == ILOpcode.callvirt); + fn = LLVMFunctionForMethod(callee, signature.IsStatic ? null : argumentValues[0], opcode == ILOpcode.callvirt, constrainedType); } LLVMValueRef llvmReturn = LLVM.BuildCall(_builder, fn, llvmArgs.ToArray(), string.Empty); @@ -1955,7 +1986,7 @@ namespace Internal.IL private void ImportCalli(int token) { MethodSignature methodSignature = (MethodSignature)_methodIL.GetObject(token); - HandleCall(null, methodSignature, ILOpcode.calli, ((ExpressionEntry)_stack.Pop()).ValueAsType(LLVM.PointerType(GetLLVMSignatureForMethod(methodSignature), 0), _builder)); + HandleCall(null, methodSignature, ILOpcode.calli, calliTarget: ((ExpressionEntry)_stack.Pop()).ValueAsType(LLVM.PointerType(GetLLVMSignatureForMethod(methodSignature), 0), _builder)); } private void ImportLdFtn(int token, ILOpcode opCode) @@ -1967,7 +1998,7 @@ namespace Internal.IL StackEntry thisPointer = _stack.Pop(); if (method.IsVirtual) { - targetLLVMFunction = LLVMFunctionForMethod(method, thisPointer, true); + targetLLVMFunction = LLVMFunctionForMethod(method, thisPointer, true, null); AddVirtualMethodReference(method); } } @@ -2652,6 +2683,7 @@ namespace Internal.IL private void ImportConstrainedPrefix(int token) { + _constrainedType = (TypeDesc)_methodIL.GetObject(token); } private void ImportNoPrefix(byte mask) diff --git a/tests/src/Simple/HelloWasm/Program.cs b/tests/src/Simple/HelloWasm/Program.cs index 95563912a..98ccb4eb1 100644 --- a/tests/src/Simple/HelloWasm/Program.cs +++ b/tests/src/Simple/HelloWasm/Program.cs @@ -15,6 +15,8 @@ internal static class Program private static int threadStaticInt; private static unsafe int Main(string[] args) { + PrintLine("Starting"); + Add(1, 2); int tempInt = 0; int tempInt2 = 0; @@ -310,6 +312,8 @@ internal static class Program PrintLine("ByReference intrinsics exercise via ReadOnlySpan OK."); } + TestConstrainedClassCalls(); + // This test should remain last to get other results before stopping the debugger PrintLine("Debugger.Break() test: Ok if debugger is open and breaks."); System.Diagnostics.Debugger.Break(); @@ -490,6 +494,113 @@ internal static class Program } } + private static void TestConstrainedClassCalls() + { + string s = "utf-8"; + + PrintString("Direct ToString test: "); + string stringDirectToString = s.ToString(); + if (s.Equals(stringDirectToString)) + { + PrintLine("Ok."); + } + else + { + PrintString("Failed. Returned string:\""); + PrintString(stringDirectToString); + PrintLine("\""); + } + + // Generic calls on methods not defined on object + uint dataFromBase = GenericGetData(new MyBase(11)); + PrintString("Generic call to base class test: "); + if (dataFromBase == 11) + { + PrintLine("Ok."); + } + else + { + PrintLine("Failed."); + } + + uint dataFromUnsealed = GenericGetData(new UnsealedDerived(13)); + PrintString("Generic call to unsealed derived class test: "); + if (dataFromUnsealed == 26) + { + PrintLine("Ok."); + } + else + { + PrintLine("Failed."); + } + + uint dataFromSealed = GenericGetData(new SealedDerived(15)); + PrintString("Generic call to sealed derived class test: "); + if (dataFromSealed == 45) + { + PrintLine("Ok."); + } + else + { + PrintLine("Failed."); + } + + uint dataFromUnsealedAsBase = GenericGetData(new UnsealedDerived(17)); + PrintString("Generic call to unsealed derived class as base test: "); + if (dataFromUnsealedAsBase == 34) + { + PrintLine("Ok."); + } + else + { + PrintLine("Failed."); + } + + uint dataFromSealedAsBase = GenericGetData(new SealedDerived(19)); + PrintString("Generic call to sealed derived class as base test: "); + if (dataFromSealedAsBase == 57) + { + PrintLine("Ok."); + } + else + { + PrintLine("Failed."); + } + + // Generic calls to methods defined on object + uint hashCodeOfSealedViaGeneric = (uint)GenericGetHashCode(new MySealedClass(37)); + PrintString("Generic GetHashCode for sealed class test: "); + if (hashCodeOfSealedViaGeneric == 74) + { + PrintLine("Ok."); + } + else + { + PrintLine("Failed."); + } + + uint hashCodeOfUnsealedViaGeneric = (uint)GenericGetHashCode(new MyUnsealedClass(41)); + PrintString("Generic GetHashCode for unsealed class test: "); + if (hashCodeOfUnsealedViaGeneric == 82) + { + PrintLine("Ok."); + } + else + { + PrintLine("Failed."); + } + } + + static uint GenericGetData(T obj) where T : MyBase + { + return obj.GetData(); + } + + static int GenericGetHashCode(T obj) + { + return obj.GetHashCode(); + } + [DllImport("*")] private static unsafe extern int printf(byte* str, byte* unused); } @@ -646,3 +757,97 @@ public struct ItfStruct : ITestItf return 4; } } + +public sealed class MySealedClass +{ + uint _data; + + public MySealedClass() + { + _data = 104; + } + + public MySealedClass(uint data) + { + _data = data; + } + + public uint GetData() + { + return _data; + } + + public override int GetHashCode() + { + return (int)_data * 2; + } + + public override string ToString() + { + Program.PrintLine("MySealedClass.ToString called. Data:"); + Program.PrintLine(_data.ToString()); + return _data.ToString(); + } +} + +public class MyUnsealedClass +{ + uint _data; + + public MyUnsealedClass() + { + _data = 24; + } + + public MyUnsealedClass(uint data) + { + _data = data; + } + + public uint GetData() + { + return _data; + } + + public override int GetHashCode() + { + return (int)_data * 2; + } + + public override string ToString() + { + return _data.ToString(); + } +} + +public class MyBase +{ + protected uint _data; + public MyBase(uint data) + { + _data = data; + } + + public virtual uint GetData() + { + return _data; + } +} + +public class UnsealedDerived : MyBase +{ + public UnsealedDerived(uint data) : base(data) { } + public override uint GetData() + { + return _data * 2; + } +} + +public sealed class SealedDerived : MyBase +{ + public SealedDerived(uint data) : base(data) { } + public override uint GetData() + { + return _data * 3; + } +} -- cgit v1.2.3