Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/dotnet/runtime.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Comparers.cs95
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportGenerator.cs504
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportStubContext.cs (renamed from src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportStub.cs)176
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/GeneratedDllImportData.cs48
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/GeneratorDiagnostics.cs108
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/ManagedTypeInfo.cs74
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BlittableMarshaller.cs2
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BoolMarshaller.cs2
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/CharMarshaller.cs2
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/DelegateMarshaller.cs2
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/Forwarder.cs4
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/HResultExceptionMarshaller.cs2
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/MarshallingGenerator.cs99
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs8
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/MarshallingAttributeInfo.cs94
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeContext.cs4
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeGenerator.cs175
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypePositionInfo.cs44
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypeSymbolExtensions.cs4
-rw-r--r--src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/UnreachableException.cs13
-rw-r--r--src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/Ancillary.Interop.csproj1
-rw-r--r--src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/IncrementalGenerationTests.cs203
-rw-r--r--src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/TestUtils.cs56
23 files changed, 1131 insertions, 589 deletions
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Comparers.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Comparers.cs
new file mode 100644
index 00000000000..e1ded7b4ab1
--- /dev/null
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Comparers.cs
@@ -0,0 +1,95 @@
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Linq;
+
+namespace Microsoft.Interop
+{
+ internal static class Comparers
+ {
+ /// <summary>
+ /// Comparer for the set of all of the generated stubs and diagnostics generated for each of them.
+ /// </summary>
+ public static readonly IEqualityComparer<ImmutableArray<(string, ImmutableArray<Diagnostic>)>> GeneratedSourceSet = new ImmutableArraySequenceEqualComparer<(string, ImmutableArray<Diagnostic>)>(new CustomValueTupleElementComparer<string, ImmutableArray<Diagnostic>>(EqualityComparer<string>.Default, new ImmutableArraySequenceEqualComparer<Diagnostic>(EqualityComparer<Diagnostic>.Default)));
+
+ /// <summary>
+ /// Comparer for an individual generated stub source as a string and the generated diagnostics for the stub.
+ /// </summary>
+ public static readonly IEqualityComparer<(string, ImmutableArray<Diagnostic>)> GeneratedSource = new CustomValueTupleElementComparer<string, ImmutableArray<Diagnostic>>(EqualityComparer<string>.Default, new ImmutableArraySequenceEqualComparer<Diagnostic>(EqualityComparer<Diagnostic>.Default));
+
+ /// <summary>
+ /// Comparer for an individual generated stub source as a syntax tree and the generated diagnostics for the stub.
+ /// </summary>
+ public static readonly IEqualityComparer<(MemberDeclarationSyntax Syntax, ImmutableArray<Diagnostic> Diagnostics)> GeneratedSyntax = new CustomValueTupleElementComparer<MemberDeclarationSyntax, ImmutableArray<Diagnostic>>(new SyntaxEquivalentComparer(), new ImmutableArraySequenceEqualComparer<Diagnostic>(EqualityComparer<Diagnostic>.Default));
+
+ /// <summary>
+ /// Comparer for the context used to generate a stub and the original user-provided syntax that triggered stub creation.
+ /// </summary>
+ public static readonly IEqualityComparer<(MethodDeclarationSyntax Syntax, DllImportGenerator.IncrementalStubGenerationContext StubContext)> CalculatedContextWithSyntax = new CustomValueTupleElementComparer<MethodDeclarationSyntax, DllImportGenerator.IncrementalStubGenerationContext>(new SyntaxEquivalentComparer(), EqualityComparer<DllImportGenerator.IncrementalStubGenerationContext>.Default);
+ }
+
+ /// <summary>
+ /// Generic comparer to compare two <see cref="ImmutableArray{T}"/> instances element by element.
+ /// </summary>
+ /// <typeparam name="T">The type of immutable array element.</typeparam>
+ internal class ImmutableArraySequenceEqualComparer<T> : IEqualityComparer<ImmutableArray<T>>
+ {
+ private readonly IEqualityComparer<T> elementComparer;
+
+ /// <summary>
+ /// Creates an <see cref="ImmutableArraySequenceEqualComparer{T}"/> with a custom comparer for the elements of the collection.
+ /// </summary>
+ /// <param name="elementComparer">The comparer instance for the collection elements.</param>
+ public ImmutableArraySequenceEqualComparer(IEqualityComparer<T> elementComparer)
+ {
+ this.elementComparer = elementComparer;
+ }
+
+ public bool Equals(ImmutableArray<T> x, ImmutableArray<T> y)
+ {
+ return x.SequenceEqual(y, elementComparer);
+ }
+
+ public int GetHashCode(ImmutableArray<T> obj)
+ {
+ throw new UnreachableException();
+ }
+ }
+
+ internal class SyntaxEquivalentComparer : IEqualityComparer<SyntaxNode>
+ {
+ public bool Equals(SyntaxNode x, SyntaxNode y)
+ {
+ return x.IsEquivalentTo(y);
+ }
+
+ public int GetHashCode(SyntaxNode obj)
+ {
+ throw new UnreachableException();
+ }
+ }
+
+ internal class CustomValueTupleElementComparer<T, U> : IEqualityComparer<(T, U)>
+ {
+ private readonly IEqualityComparer<T> item1Comparer;
+ private readonly IEqualityComparer<U> item2Comparer;
+
+ public CustomValueTupleElementComparer(IEqualityComparer<T> item1Comparer, IEqualityComparer<U> item2Comparer)
+ {
+ this.item1Comparer = item1Comparer;
+ this.item2Comparer = item2Comparer;
+ }
+
+ public bool Equals((T, U) x, (T, U) y)
+ {
+ return item1Comparer.Equals(x.Item1, y.Item1) && item2Comparer.Equals(x.Item2, y.Item2);
+ }
+
+ public int GetHashCode((T, U) obj)
+ {
+ throw new UnreachableException();
+ }
+ }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportGenerator.cs
index 802858ddb1f..92668a646d4 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportGenerator.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportGenerator.cs
@@ -1,164 +1,186 @@
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Microsoft.CodeAnalysis.Diagnostics;
using System;
using System.Collections.Generic;
+using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
-
-using Microsoft.CodeAnalysis;
-using Microsoft.CodeAnalysis.CSharp;
-using Microsoft.CodeAnalysis.CSharp.Syntax;
-using Microsoft.CodeAnalysis.Text;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
namespace Microsoft.Interop
{
[Generator]
- public class DllImportGenerator : ISourceGenerator
+ public class DllImportGenerator : IIncrementalGenerator
{
private const string GeneratedDllImport = nameof(GeneratedDllImport);
private const string GeneratedDllImportAttribute = nameof(GeneratedDllImportAttribute);
private static readonly Version MinimumSupportedFrameworkVersion = new Version(5, 0);
- public void Execute(GeneratorExecutionContext context)
+ internal sealed record IncrementalStubGenerationContext(DllImportStubContext StubContext, ImmutableArray<AttributeSyntax> ForwardedAttributes, GeneratedDllImportData DllImportData, ImmutableArray<Diagnostic> Diagnostics)
{
- if (context.SyntaxContextReceiver is not SyntaxContextReceiver synRec
- || !synRec.Methods.Any())
+ public bool Equals(IncrementalStubGenerationContext? other)
{
- return;
+ return other is not null
+ && StubContext.Equals(other.StubContext)
+ && DllImportData.Equals(other.DllImportData)
+ && ForwardedAttributes.SequenceEqual(other.ForwardedAttributes, (IEqualityComparer<AttributeSyntax>)new SyntaxEquivalentComparer())
+ && Diagnostics.SequenceEqual(other.Diagnostics);
}
- INamedTypeSymbol? lcidConversionAttrType = context.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute);
-
- INamedTypeSymbol? suppressGCTransitionAttrType = context.Compilation.GetTypeByMetadataName(TypeNames.SuppressGCTransitionAttribute);
-
- INamedTypeSymbol? unmanagedCallConvAttrType = context.Compilation.GetTypeByMetadataName(TypeNames.UnmanagedCallConvAttribute);
-
- // Fire the start/stop pair for source generation
- using var _ = Diagnostics.Events.SourceGenerationStartStop(synRec.Methods.Count);
-
- // Store a mapping between SyntaxTree and SemanticModel.
- // SemanticModels cache results and since we could be looking at
- // method declarations in the same SyntaxTree we want to benefit from
- // this caching.
- var syntaxToModel = new Dictionary<SyntaxTree, SemanticModel>();
-
- var generatorDiagnostics = new GeneratorDiagnostics(context);
+ public override int GetHashCode()
+ {
+ throw new UnreachableException();
+ }
+ }
- bool isSupported = IsSupportedTargetFramework(context.Compilation, out Version targetFrameworkVersion);
- if (!isSupported)
+ public class IncrementalityTracker
+ {
+ public enum StepName
{
- // We don't return early here, letting the source generation continue.
- // This allows a user to copy generated source and use it as a starting point
- // for manual marshalling if desired.
- generatorDiagnostics.ReportTargetFrameworkNotSupported(MinimumSupportedFrameworkVersion);
+ CalculateStubInformation,
+ GenerateSingleStub,
+ NormalizeWhitespace,
+ ConcatenateStubs,
+ OutputSourceFile
}
- var env = new StubEnvironment(
- context.Compilation,
- isSupported,
- targetFrameworkVersion,
- context.AnalyzerConfigOptions.GlobalOptions,
- context.Compilation.SourceModule.GetAttributes()
- .Any(a => a.AttributeClass?.ToDisplayString() == TypeNames.System_Runtime_CompilerServices_SkipLocalsInitAttribute));
+ public record ExecutedStepInfo(StepName Step, object Input);
- var generatedDllImports = new StringBuilder();
+ private List<ExecutedStepInfo> executedSteps = new();
+ public IEnumerable<ExecutedStepInfo> ExecutedSteps => executedSteps;
- // Mark in source that the file is auto-generated.
- generatedDllImports.AppendLine("// <auto-generated/>");
+ internal void RecordExecutedStep(ExecutedStepInfo step) => executedSteps.Add(step);
+ }
- foreach (SyntaxReference synRef in synRec.Methods)
- {
- var methodSyntax = (MethodDeclarationSyntax)synRef.GetSyntax(context.CancellationToken);
+ /// <summary>
+ /// This property provides a test-only hook to enable testing the incrementality of the source generator.
+ /// This will be removed when https://github.com/dotnet/roslyn/issues/54832 is implemented and can be consumed.
+ /// </summary>
+ public IncrementalityTracker? IncrementalTracker { get; set; }
- // Get the model for the method.
- if (!syntaxToModel.TryGetValue(methodSyntax.SyntaxTree, out SemanticModel sm))
+ public void Initialize(IncrementalGeneratorInitializationContext context)
+ {
+ var methodsToGenerate = context.SyntaxProvider
+ .CreateSyntaxProvider(
+ static (node, ct) => ShouldVisitNode(node),
+ static (context, ct) =>
+ new
+ {
+ Syntax = (MethodDeclarationSyntax)context.Node,
+ Symbol = (IMethodSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node, ct)!
+ })
+ .Where(
+ static modelData => modelData.Symbol.IsStatic && modelData.Symbol.GetAttributes().Any(
+ static attribute => attribute.AttributeClass?.ToDisplayString() == TypeNames.GeneratedDllImportAttribute)
+ );
+
+ var compilationAndTargetFramework = context.CompilationProvider
+ .Select(static (compilation, ct) =>
{
- sm = context.Compilation.GetSemanticModel(methodSyntax.SyntaxTree, ignoreAccessibility: true);
- syntaxToModel.Add(methodSyntax.SyntaxTree, sm);
- }
-
- // Process the method syntax and get its SymbolInfo.
- var methodSymbolInfo = sm.GetDeclaredSymbol(methodSyntax, context.CancellationToken)!;
-
- // Get any attributes of interest on the method
- AttributeData? generatedDllImportAttr = null;
- AttributeData? lcidConversionAttr = null;
- AttributeData? suppressGCTransitionAttribute = null;
- AttributeData? unmanagedCallConvAttribute = null;
-
- foreach (var attr in methodSymbolInfo.GetAttributes())
+ bool isSupported = IsSupportedTargetFramework(compilation, out Version targetFrameworkVersion);
+ return (compilation, isSupported, targetFrameworkVersion);
+ });
+
+ context.RegisterSourceOutput(
+ compilationAndTargetFramework
+ .Combine(methodsToGenerate.Collect()),
+ static (context, data) =>
{
- if (attr.AttributeClass is null)
+ if (!data.Left.isSupported && data.Right.Any())
{
- continue;
+ // We don't block source generation when the TFM is unsupported.
+ // This allows a user to copy generated source and use it as a starting point
+ // for manual marshalling if desired.
+ context.ReportDiagnostic(
+ Diagnostic.Create(
+ GeneratorDiagnostics.TargetFrameworkNotSupported,
+ Location.None,
+ MinimumSupportedFrameworkVersion.ToString(2)));
}
- else if (attr.AttributeClass.ToDisplayString() == TypeNames.GeneratedDllImportAttribute)
+ });
+
+ var stubEnvironment = compilationAndTargetFramework
+ .Combine(context.AnalyzerConfigOptionsProvider)
+ .Select(
+ static (data, ct) =>
+ new StubEnvironment(
+ data.Left.compilation,
+ data.Left.isSupported,
+ data.Left.targetFrameworkVersion,
+ data.Right.GlobalOptions,
+ data.Left.compilation.SourceModule.GetAttributes().Any(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.System_Runtime_CompilerServices_SkipLocalsInitAttribute))
+ );
+
+ var methodSourceAndDiagnostics = methodsToGenerate
+ .Combine(stubEnvironment)
+ .Select(static (data, ct) => new
+ {
+ data.Left.Syntax,
+ data.Left.Symbol,
+ Environment = data.Right
+ })
+ .Select(
+ (data, ct) =>
{
- generatedDllImportAttr = attr;
+ IncrementalTracker?.RecordExecutedStep(new IncrementalityTracker.ExecutedStepInfo(IncrementalityTracker.StepName.CalculateStubInformation, data));
+ return (data.Syntax, StubContext: CalculateStubInformation(data.Syntax, data.Symbol, data.Environment, ct));
}
- else if (lcidConversionAttrType is not null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, lcidConversionAttrType))
+ )
+ .WithComparer(Comparers.CalculatedContextWithSyntax)
+ .Combine(context.AnalyzerConfigOptionsProvider)
+ .Select(
+ (data, ct) =>
{
- lcidConversionAttr = attr;
+ IncrementalTracker?.RecordExecutedStep(new IncrementalityTracker.ExecutedStepInfo(IncrementalityTracker.StepName.GenerateSingleStub, data));
+ return GenerateSource(data.Left.StubContext, data.Left.Syntax, data.Right.GlobalOptions);
}
- else if (suppressGCTransitionAttrType is not null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, suppressGCTransitionAttrType))
+ )
+ .WithComparer(Comparers.GeneratedSyntax)
+ // Handle NormalizeWhitespace as a separate stage for incremental runs since it is an expensive operation.
+ .Select(
+ (data, ct) =>
{
- suppressGCTransitionAttribute = attr;
- }
- else if (unmanagedCallConvAttrType is not null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, unmanagedCallConvAttrType))
+ IncrementalTracker?.RecordExecutedStep(new IncrementalityTracker.ExecutedStepInfo(IncrementalityTracker.StepName.NormalizeWhitespace, data));
+ return (data.Item1.NormalizeWhitespace().ToFullString(), data.Item2);
+ })
+ .Collect()
+ .WithComparer(Comparers.GeneratedSourceSet)
+ .Select((generatedSources, ct) =>
+ {
+ IncrementalTracker?.RecordExecutedStep(new IncrementalityTracker.ExecutedStepInfo(IncrementalityTracker.StepName.ConcatenateStubs, generatedSources));
+ StringBuilder source = new();
+ // Mark in source that the file is auto-generated.
+ source.AppendLine("// <auto-generated/>");
+ ImmutableArray<Diagnostic>.Builder diagnostics = ImmutableArray.CreateBuilder<Diagnostic>();
+ foreach (var generated in generatedSources)
{
- unmanagedCallConvAttribute = attr;
+ source.AppendLine(generated.Item1);
+ diagnostics.AddRange(generated.Item2);
}
- }
-
- if (generatedDllImportAttr == null)
- continue;
-
- // Process the GeneratedDllImport attribute
- DllImportStub.GeneratedDllImportData stubDllImportData = this.ProcessGeneratedDllImportAttribute(generatedDllImportAttr);
- Debug.Assert(stubDllImportData is not null);
-
- if (stubDllImportData!.IsUserDefined.HasFlag(DllImportStub.DllImportMember.BestFitMapping))
- {
- generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr, nameof(DllImportStub.GeneratedDllImportData.BestFitMapping));
- }
-
- if (stubDllImportData!.IsUserDefined.HasFlag(DllImportStub.DllImportMember.ThrowOnUnmappableChar))
- {
- generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr, nameof(DllImportStub.GeneratedDllImportData.ThrowOnUnmappableChar));
- }
-
- if (stubDllImportData!.IsUserDefined.HasFlag(DllImportStub.DllImportMember.CallingConvention))
- {
- generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr, nameof(DllImportStub.GeneratedDllImportData.CallingConvention));
- }
+ return (source: source.ToString(), diagnostics: diagnostics.ToImmutable());
+ })
+ .WithComparer(Comparers.GeneratedSource);
- if (lcidConversionAttr != null)
+ context.RegisterSourceOutput(methodSourceAndDiagnostics,
+ (context, data) =>
{
- // Using LCIDConversion with GeneratedDllImport is not supported
- generatorDiagnostics.ReportConfigurationNotSupported(lcidConversionAttr, nameof(TypeNames.LCIDConversionAttribute));
- }
-
- List<AttributeSyntax> forwardedAttributes = GenerateSyntaxForForwardedAttributes(suppressGCTransitionAttribute, unmanagedCallConvAttribute);
-
- // Create the stub.
- var dllImportStub = DllImportStub.Create(methodSymbolInfo, stubDllImportData!, env, generatorDiagnostics, forwardedAttributes, context.CancellationToken);
-
- PrintGeneratedSource(generatedDllImports, methodSyntax, dllImportStub);
- }
-
- Debug.WriteLine(generatedDllImports.ToString()); // [TODO] Find some way to emit this for debugging - logs?
- context.AddSource("DllImportGenerator.g.cs", SourceText.From(generatedDllImports.ToString(), Encoding.UTF8));
- }
+ IncrementalTracker?.RecordExecutedStep(new IncrementalityTracker.ExecutedStepInfo(IncrementalityTracker.StepName.OutputSourceFile, data));
+ foreach (var diagnostic in data.Item2)
+ {
+ context.ReportDiagnostic(diagnostic);
+ }
- public void Initialize(GeneratorInitializationContext context)
- {
- context.RegisterForSyntaxNotifications(() => new SyntaxContextReceiver());
+ context.AddSource("GeneratedDllImports.g.cs", data.Item1);
+ });
}
-
- private List<AttributeSyntax> GenerateSyntaxForForwardedAttributes(AttributeData? suppressGCTransitionAttribute, AttributeData? unmanagedCallConvAttribute)
+
+ private static List<AttributeSyntax> GenerateSyntaxForForwardedAttributes(AttributeData? suppressGCTransitionAttribute, AttributeData? unmanagedCallConvAttribute)
{
const string CallConvsField = "CallConvs";
// Manually rehydrate the forwarded attributes with fully qualified types so we don't have to worry about any using directives.
@@ -196,7 +218,7 @@ namespace Microsoft.Interop
return attributes;
}
- private SyntaxTokenList StripTriviaFromModifiers(SyntaxTokenList tokenList)
+ private static SyntaxTokenList StripTriviaFromModifiers(SyntaxTokenList tokenList)
{
SyntaxToken[] strippedTokens = new SyntaxToken[tokenList.Count];
for (int i = 0; i < tokenList.Count; i++)
@@ -206,7 +228,7 @@ namespace Microsoft.Interop
return new SyntaxTokenList(strippedTokens);
}
- private TypeDeclarationSyntax CreateTypeDeclarationWithoutTrivia(TypeDeclarationSyntax typeDeclaration)
+ private static TypeDeclarationSyntax CreateTypeDeclarationWithoutTrivia(TypeDeclarationSyntax typeDeclaration)
{
return TypeDeclaration(
typeDeclaration.Kind(),
@@ -215,17 +237,17 @@ namespace Microsoft.Interop
.WithModifiers(typeDeclaration.Modifiers);
}
- private void PrintGeneratedSource(
- StringBuilder builder,
+ private static MemberDeclarationSyntax PrintGeneratedSource(
MethodDeclarationSyntax userDeclaredMethod,
- DllImportStub stub)
+ DllImportStubContext stub,
+ BlockSyntax stubCode)
{
// Create stub function
var stubMethod = MethodDeclaration(stub.StubReturnType, userDeclaredMethod.Identifier)
- .AddAttributeLists(stub.AdditionalAttributes)
+ .AddAttributeLists(stub.AdditionalAttributes.ToArray())
.WithModifiers(StripTriviaFromModifiers(userDeclaredMethod.Modifiers))
.WithParameterList(ParameterList(SeparatedList(stub.StubParameters)))
- .WithBody(stub.StubCode);
+ .WithBody(stubCode);
// Stub should have at least one containing type
Debug.Assert(stub.StubContainingTypes.Any());
@@ -250,7 +272,7 @@ namespace Microsoft.Interop
.AddMembers(toPrint);
}
- builder.AppendLine(toPrint.NormalizeWhitespace().ToString());
+ return toPrint;
}
private static bool IsSupportedTargetFramework(Compilation compilation, out Version version)
@@ -270,10 +292,8 @@ namespace Microsoft.Interop
};
}
- private DllImportStub.GeneratedDllImportData ProcessGeneratedDllImportAttribute(AttributeData attrData)
+ private static GeneratedDllImportData ProcessGeneratedDllImportAttribute(AttributeData attrData)
{
- var stubDllImportData = new DllImportStub.GeneratedDllImportData();
-
// Found the GeneratedDllImport, but it has an error so report the error.
// This is most likely an issue with targeting an incorrect TFM.
if (attrData.AttributeClass?.TypeKind is null or TypeKind.Error)
@@ -282,8 +302,21 @@ namespace Microsoft.Interop
throw new InvalidProgramException();
}
- // Populate the DllImport data from the GeneratedDllImportAttribute attribute.
- stubDllImportData.ModuleName = attrData.ConstructorArguments[0].Value!.ToString();
+
+ // Default values for these properties are based on the
+ // documented semanatics of DllImportAttribute:
+ // - https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.dllimportattribute
+ DllImportMember userDefinedValues = DllImportMember.None;
+ bool bestFitMapping = false;
+ CallingConvention callingConvention = CallingConvention.Winapi;
+ CharSet charSet = CharSet.Ansi;
+ string? entryPoint = null;
+ bool exactSpelling = false; // VB has different and unusual default behavior here.
+ bool preserveSig = true;
+ bool setLastError = false;
+ bool throwOnUnmappableChar = false;
+
+ var stubDllImportData = new GeneratedDllImportData(attrData.ConstructorArguments[0].Value!.ToString());
// All other data on attribute is defined as NamedArguments.
foreach (var namedArg in attrData.NamedArguments)
@@ -293,96 +326,179 @@ namespace Microsoft.Interop
default:
Debug.Fail($"An unknown member was found on {GeneratedDllImport}");
continue;
- case nameof(DllImportStub.GeneratedDllImportData.BestFitMapping):
- stubDllImportData.BestFitMapping = (bool)namedArg.Value.Value!;
- stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.BestFitMapping;
+ case nameof(GeneratedDllImportData.BestFitMapping):
+ userDefinedValues |= DllImportMember.BestFitMapping;
+ bestFitMapping = (bool)namedArg.Value.Value!;
break;
- case nameof(DllImportStub.GeneratedDllImportData.CallingConvention):
- stubDllImportData.CallingConvention = (CallingConvention)namedArg.Value.Value!;
- stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.CallingConvention;
+ case nameof(GeneratedDllImportData.CallingConvention):
+ userDefinedValues |= DllImportMember.CallingConvention;
+ callingConvention = (CallingConvention)namedArg.Value.Value!;
break;
- case nameof(DllImportStub.GeneratedDllImportData.CharSet):
- stubDllImportData.CharSet = (CharSet)namedArg.Value.Value!;
- stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.CharSet;
+ case nameof(GeneratedDllImportData.CharSet):
+ userDefinedValues |= DllImportMember.CharSet;
+ charSet = (CharSet)namedArg.Value.Value!;
break;
- case nameof(DllImportStub.GeneratedDllImportData.EntryPoint):
- stubDllImportData.EntryPoint = (string)namedArg.Value.Value!;
- stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.EntryPoint;
+ case nameof(GeneratedDllImportData.EntryPoint):
+ userDefinedValues |= DllImportMember.EntryPoint;
+ entryPoint = (string)namedArg.Value.Value!;
break;
- case nameof(DllImportStub.GeneratedDllImportData.ExactSpelling):
- stubDllImportData.ExactSpelling = (bool)namedArg.Value.Value!;
- stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.ExactSpelling;
+ case nameof(GeneratedDllImportData.ExactSpelling):
+ userDefinedValues |= DllImportMember.ExactSpelling;
+ exactSpelling = (bool)namedArg.Value.Value!;
break;
- case nameof(DllImportStub.GeneratedDllImportData.PreserveSig):
- stubDllImportData.PreserveSig = (bool)namedArg.Value.Value!;
- stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.PreserveSig;
+ case nameof(GeneratedDllImportData.PreserveSig):
+ userDefinedValues |= DllImportMember.PreserveSig;
+ preserveSig = (bool)namedArg.Value.Value!;
break;
- case nameof(DllImportStub.GeneratedDllImportData.SetLastError):
- stubDllImportData.SetLastError = (bool)namedArg.Value.Value!;
- stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.SetLastError;
+ case nameof(GeneratedDllImportData.SetLastError):
+ userDefinedValues |= DllImportMember.SetLastError;
+ setLastError = (bool)namedArg.Value.Value!;
break;
- case nameof(DllImportStub.GeneratedDllImportData.ThrowOnUnmappableChar):
- stubDllImportData.ThrowOnUnmappableChar = (bool)namedArg.Value.Value!;
- stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.ThrowOnUnmappableChar;
+ case nameof(GeneratedDllImportData.ThrowOnUnmappableChar):
+ userDefinedValues |= DllImportMember.ThrowOnUnmappableChar;
+ throwOnUnmappableChar = (bool)namedArg.Value.Value!;
break;
}
}
- return stubDllImportData;
+ return new GeneratedDllImportData(attrData.ConstructorArguments[0].Value!.ToString())
+ {
+ IsUserDefined = userDefinedValues,
+ BestFitMapping = bestFitMapping,
+ CallingConvention = callingConvention,
+ CharSet = charSet,
+ EntryPoint = entryPoint,
+ ExactSpelling = exactSpelling,
+ PreserveSig = preserveSig,
+ SetLastError = setLastError,
+ ThrowOnUnmappableChar = throwOnUnmappableChar
+ };
}
-
- private class SyntaxContextReceiver : ISyntaxContextReceiver
+ private static IncrementalStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, StubEnvironment environment, CancellationToken ct)
{
- public ICollection<SyntaxReference> Methods { get; } = new List<SyntaxReference>();
-
- public void OnVisitSyntaxNode(GeneratorSyntaxContext context)
+ INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute);
+ INamedTypeSymbol? suppressGCTransitionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.SuppressGCTransitionAttribute);
+ INamedTypeSymbol? unmanagedCallConvAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.UnmanagedCallConvAttribute);
+ // Get any attributes of interest on the method
+ AttributeData? generatedDllImportAttr = null;
+ AttributeData? lcidConversionAttr = null;
+ AttributeData? suppressGCTransitionAttribute = null;
+ AttributeData? unmanagedCallConvAttribute = null;
+ foreach (var attr in symbol.GetAttributes())
{
- SyntaxNode syntaxNode = context.Node;
-
- // We only support C# method declarations.
- if (syntaxNode.Language != LanguageNames.CSharp
- || !syntaxNode.IsKind(SyntaxKind.MethodDeclaration))
+ if (attr.AttributeClass is not null
+ && attr.AttributeClass.ToDisplayString() == TypeNames.GeneratedDllImportAttribute)
{
- return;
+ generatedDllImportAttr = attr;
}
-
- var methodSyntax = (MethodDeclarationSyntax)syntaxNode;
-
- // Verify the method has no generic types or defined implementation
- // and is marked static and partial.
- if (!(methodSyntax.TypeParameterList is null)
- || !(methodSyntax.Body is null)
- || !methodSyntax.Modifiers.Any(SyntaxKind.StaticKeyword)
- || !methodSyntax.Modifiers.Any(SyntaxKind.PartialKeyword))
+ else if (lcidConversionAttrType != null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, lcidConversionAttrType))
{
- return;
+ lcidConversionAttr = attr;
}
-
- // Verify that the types the method is declared in are marked partial.
- for (SyntaxNode? parentNode = methodSyntax.Parent; parentNode is TypeDeclarationSyntax typeDecl; parentNode = parentNode.Parent)
+ else if (suppressGCTransitionAttrType != null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, suppressGCTransitionAttrType))
{
- if (!typeDecl.Modifiers.Any(SyntaxKind.PartialKeyword))
- {
- return;
- }
+ suppressGCTransitionAttribute = attr;
}
+ else if (unmanagedCallConvAttrType != null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, unmanagedCallConvAttrType))
+ {
+ unmanagedCallConvAttribute = attr;
+ }
+ }
+
+ Debug.Assert(generatedDllImportAttr is not null);
+
+ var generatorDiagnostics = new GeneratorDiagnostics();
+
+ // Process the GeneratedDllImport attribute
+ GeneratedDllImportData stubDllImportData = ProcessGeneratedDllImportAttribute(generatedDllImportAttr!);
+
+ if (stubDllImportData.IsUserDefined.HasFlag(DllImportMember.BestFitMapping))
+ {
+ generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr!, nameof(GeneratedDllImportData.BestFitMapping));
+ }
+
+ if (stubDllImportData.IsUserDefined.HasFlag(DllImportMember.ThrowOnUnmappableChar))
+ {
+ generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr!, nameof(GeneratedDllImportData.ThrowOnUnmappableChar));
+ }
+
+ if (stubDllImportData.IsUserDefined.HasFlag(DllImportMember.CallingConvention))
+ {
+ generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr!, nameof(GeneratedDllImportData.CallingConvention));
+ }
+
+ if (lcidConversionAttr != null)
+ {
+ // Using LCIDConversion with GeneratedDllImport is not supported
+ generatorDiagnostics.ReportConfigurationNotSupported(lcidConversionAttr, nameof(TypeNames.LCIDConversionAttribute));
+ }
+ List<AttributeSyntax> additionalAttributes = GenerateSyntaxForForwardedAttributes(suppressGCTransitionAttribute, unmanagedCallConvAttribute);
- // Check if the method is marked with the GeneratedDllImport attribute.
- foreach (AttributeListSyntax listSyntax in methodSyntax.AttributeLists)
+ // Create the stub.
+ var dllImportStub = DllImportStubContext.Create(symbol, stubDllImportData, environment, generatorDiagnostics, ct);
+
+ return new IncrementalStubGenerationContext(dllImportStub, additionalAttributes.ToImmutableArray(), stubDllImportData, generatorDiagnostics.Diagnostics.ToImmutableArray());
+ }
+
+ private (MemberDeclarationSyntax, ImmutableArray<Diagnostic>) GenerateSource(
+ IncrementalStubGenerationContext dllImportStub,
+ MethodDeclarationSyntax originalSyntax,
+ AnalyzerConfigOptions options)
+ {
+ var diagnostics = new GeneratorDiagnostics();
+
+ // Generate stub code
+ var stubGenerator = new StubCodeGenerator(
+ dllImportStub.DllImportData,
+ dllImportStub.StubContext.ElementTypeInformation,
+ options,
+ (elementInfo, ex) => diagnostics.ReportMarshallingNotSupported(originalSyntax, elementInfo, ex.NotSupportedDetails));
+
+ ImmutableArray<AttributeSyntax> forwardedAttributes = dllImportStub.ForwardedAttributes;
+
+ var code = stubGenerator.GenerateBody(originalSyntax.Identifier.Text, forwardedAttributes: forwardedAttributes.Length != 0 ? AttributeList(SeparatedList(forwardedAttributes)) : null);
+
+ return (PrintGeneratedSource(originalSyntax, dllImportStub.StubContext, code), dllImportStub.Diagnostics.AddRange(diagnostics.Diagnostics));
+ }
+
+ private static bool ShouldVisitNode(SyntaxNode syntaxNode)
+ {
+ // We only support C# method declarations.
+ if (syntaxNode.Language != LanguageNames.CSharp
+ || !syntaxNode.IsKind(SyntaxKind.MethodDeclaration))
+ {
+ return false;
+ }
+
+ var methodSyntax = (MethodDeclarationSyntax)syntaxNode;
+
+ // Verify the method has no generic types or defined implementation
+ // and is marked static and partial.
+ if (methodSyntax.TypeParameterList is not null
+ || methodSyntax.Body is not null
+ || !methodSyntax.Modifiers.Any(SyntaxKind.StaticKeyword)
+ || !methodSyntax.Modifiers.Any(SyntaxKind.PartialKeyword))
+ {
+ return false;
+ }
+
+ // Verify that the types the method is declared in are marked partial.
+ for (SyntaxNode? parentNode = methodSyntax.Parent; parentNode is TypeDeclarationSyntax typeDecl; parentNode = parentNode.Parent)
+ {
+ if (!typeDecl.Modifiers.Any(SyntaxKind.PartialKeyword))
{
- foreach (AttributeSyntax attrSyntax in listSyntax.Attributes)
- {
- SymbolInfo info = context.SemanticModel.GetSymbolInfo(attrSyntax);
- if (info.Symbol is IMethodSymbol attrConstructor
- && attrConstructor.ContainingType.ToDisplayString() == TypeNames.GeneratedDllImportAttribute)
- {
- this.Methods.Add(syntaxNode.GetReference());
- return;
- }
- }
+ return false;
}
}
+
+ // Filter out methods with no attributes early.
+ if (methodSyntax.AttributeLists.Count == 0)
+ {
+ return false;
+ }
+
+ return true;
}
}
}
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportStub.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportStubContext.cs
index c82095e5b5a..e86b8fdef82 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportStub.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportStubContext.cs
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
+using System.Collections.Immutable;
using System.Linq;
using System.Runtime.InteropServices;
using System.Threading;
@@ -19,101 +20,50 @@ namespace Microsoft.Interop
AnalyzerConfigOptions Options,
bool ModuleSkipLocalsInit);
- internal class DllImportStub
+ internal sealed class DllImportStubContext : IEquatable<DllImportStubContext>
{
- private TypePositionInfo returnTypeInfo;
- private IEnumerable<TypePositionInfo> paramsTypeInfo;
-
// We don't need the warnings around not setting the various
// non-nullable fields/properties on this type in the constructor
// since we always use a property initializer.
#pragma warning disable 8618
- private DllImportStub()
+ private DllImportStubContext()
{
}
#pragma warning restore
+ public ImmutableArray<TypePositionInfo> ElementTypeInformation { get; init; }
+
public string? StubTypeNamespace { get; init; }
- public IEnumerable<TypeDeclarationSyntax> StubContainingTypes { get; init; }
+ public ImmutableArray<TypeDeclarationSyntax> StubContainingTypes { get; init; }
- public TypeSyntax StubReturnType { get => this.returnTypeInfo.ManagedType.AsTypeSyntax(); }
+ public TypeSyntax StubReturnType { get; init; }
public IEnumerable<ParameterSyntax> StubParameters
{
get
{
- foreach (var typeinfo in paramsTypeInfo)
+ foreach (var typeInfo in ElementTypeInformation)
{
- if (typeinfo.ManagedIndex != TypePositionInfo.UnsetIndex
- && typeinfo.ManagedIndex != TypePositionInfo.ReturnIndex)
+ if (typeInfo.ManagedIndex != TypePositionInfo.UnsetIndex
+ && typeInfo.ManagedIndex != TypePositionInfo.ReturnIndex)
{
- yield return Parameter(Identifier(typeinfo.InstanceIdentifier))
- .WithType(typeinfo.ManagedType.AsTypeSyntax())
- .WithModifiers(TokenList(Token(typeinfo.RefKindSyntax)));
+ yield return Parameter(Identifier(typeInfo.InstanceIdentifier))
+ .WithType(typeInfo.ManagedType.Syntax)
+ .WithModifiers(TokenList(Token(typeInfo.RefKindSyntax)));
}
}
}
}
- public BlockSyntax StubCode { get; init; }
+ public ImmutableArray<AttributeListSyntax> AdditionalAttributes { get; init; }
- public AttributeListSyntax[] AdditionalAttributes { get; init; }
-
- /// <summary>
- /// Flags used to indicate members on GeneratedDllImport attribute.
- /// </summary>
- [Flags]
- public enum DllImportMember
- {
- None = 0,
- BestFitMapping = 1 << 0,
- CallingConvention = 1 << 1,
- CharSet = 1 << 2,
- EntryPoint = 1 << 3,
- ExactSpelling = 1 << 4,
- PreserveSig = 1 << 5,
- SetLastError = 1 << 6,
- ThrowOnUnmappableChar = 1 << 7,
- All = ~None
- }
-
- /// <summary>
- /// GeneratedDllImportAttribute data
- /// </summary>
- /// <remarks>
- /// The names of these members map directly to those on the
- /// DllImportAttribute and should not be changed.
- /// </remarks>
- public class GeneratedDllImportData
- {
- public string ModuleName { get; set; } = null!;
-
- /// <summary>
- /// Value set by the user on the original declaration.
- /// </summary>
- public DllImportMember IsUserDefined = DllImportMember.None;
-
- // Default values for the below fields are based on the
- // documented semanatics of DllImportAttribute:
- // - https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.dllimportattribute
- public bool BestFitMapping { get; set; } = true;
- public CallingConvention CallingConvention { get; set; } = CallingConvention.Winapi;
- public CharSet CharSet { get; set; } = CharSet.Ansi;
- public string EntryPoint { get; set; } = null!;
- public bool ExactSpelling { get; set; } = false; // VB has different and unusual default behavior here.
- public bool PreserveSig { get; set; } = true;
- public bool SetLastError { get; set; } = false;
- public bool ThrowOnUnmappableChar { get; set; } = false;
- }
-
- public static DllImportStub Create(
+ public static DllImportStubContext Create(
IMethodSymbol method,
GeneratedDllImportData dllImportData,
StubEnvironment env,
GeneratorDiagnostics diagnostics,
- List<AttributeSyntax> forwardedAttributes,
- CancellationToken token = default)
+ CancellationToken token)
{
// Cancel early if requested
token.ThrowIfCancellationRequested();
@@ -127,7 +77,7 @@ namespace Microsoft.Interop
}
// Determine containing type(s)
- var containingTypes = new List<TypeDeclarationSyntax>();
+ var containingTypes = ImmutableArray.CreateBuilder<TypeDeclarationSyntax>();
INamedTypeSymbol currType = method.ContainingType;
while (!(currType is null))
{
@@ -145,6 +95,36 @@ namespace Microsoft.Interop
currType = currType.ContainingType;
}
+ var typeInfos = GenerateTypeInformation(method, dllImportData, diagnostics, env);
+
+ var additionalAttrs = ImmutableArray.CreateBuilder<AttributeListSyntax>();
+
+ // Define additional attributes for the stub definition.
+ if (env.TargetFrameworkVersion >= new Version(5, 0) && !MethodIsSkipLocalsInit(env, method))
+ {
+ additionalAttrs.Add(
+ AttributeList(
+ SeparatedList(new[]
+ {
+ // Adding the skip locals init indiscriminately since the source generator is
+ // targeted at non-blittable method signatures which typically will contain locals
+ // in the generated code.
+ Attribute(ParseName(TypeNames.System_Runtime_CompilerServices_SkipLocalsInitAttribute))
+ })));
+ }
+
+ return new DllImportStubContext()
+ {
+ StubReturnType = method.ReturnType.AsTypeSyntax(),
+ ElementTypeInformation = typeInfos,
+ StubTypeNamespace = stubTypeNamespace,
+ StubContainingTypes = containingTypes.ToImmutable(),
+ AdditionalAttributes = additionalAttrs.ToImmutable(),
+ };
+ }
+
+ private static ImmutableArray<TypePositionInfo> GenerateTypeInformation(IMethodSymbol method, GeneratedDllImportData dllImportData, GeneratorDiagnostics diagnostics, StubEnvironment env)
+ {
// Compute the current default string encoding value.
var defaultEncoding = CharEncoding.Undefined;
if (dllImportData.IsUserDefined.HasFlag(DllImportMember.CharSet))
@@ -163,21 +143,22 @@ namespace Microsoft.Interop
var marshallingAttributeParser = new MarshallingAttributeInfoParser(env.Compilation, diagnostics, defaultInfo, method);
// Determine parameter and return types
- var paramsTypeInfo = new List<TypePositionInfo>();
+ var typeInfos = ImmutableArray.CreateBuilder<TypePositionInfo>();
for (int i = 0; i < method.Parameters.Length; i++)
{
var param = method.Parameters[i];
MarshallingInfo marshallingInfo = marshallingAttributeParser.ParseMarshallingInfo(param.Type, param.GetAttributes());
var typeInfo = TypePositionInfo.CreateForParameter(param, marshallingInfo, env.Compilation);
- typeInfo = typeInfo with
+ typeInfo = typeInfo with
{
ManagedIndex = i,
- NativeIndex = paramsTypeInfo.Count
+ NativeIndex = typeInfos.Count
};
- paramsTypeInfo.Add(typeInfo);
+ typeInfos.Add(typeInfo);
+
}
- TypePositionInfo retTypeInfo = TypePositionInfo.CreateForType(method.ReturnType, marshallingAttributeParser.ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes()));
+ TypePositionInfo retTypeInfo = new(ManagedTypeInfo.CreateTypeInfoForTypeSymbol(method.ReturnType), marshallingAttributeParser.ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes()));
retTypeInfo = retTypeInfo with
{
ManagedIndex = TypePositionInfo.ReturnIndex,
@@ -190,7 +171,7 @@ namespace Microsoft.Interop
if (!dllImportData.PreserveSig && !env.Options.GenerateForwarders())
{
// Create type info for native HRESULT return
- retTypeInfo = TypePositionInfo.CreateForType(env.Compilation.GetSpecialType(SpecialType.System_Int32), NoMarshallingInfo.Instance);
+ retTypeInfo = new TypePositionInfo(SpecialTypeInfo.Int32, NoMarshallingInfo.Instance);
retTypeInfo = retTypeInfo with
{
NativeIndex = TypePositionInfo.ReturnIndex
@@ -206,41 +187,34 @@ namespace Microsoft.Interop
RefKind = RefKind.Out,
RefKindSyntax = SyntaxKind.OutKeyword,
ManagedIndex = TypePositionInfo.ReturnIndex,
- NativeIndex = paramsTypeInfo.Count
+ NativeIndex = typeInfos.Count
};
- paramsTypeInfo.Add(nativeOutInfo);
+ typeInfos.Add(nativeOutInfo);
}
}
+ typeInfos.Add(retTypeInfo);
- // Generate stub code
- var stubGenerator = new StubCodeGenerator(method, dllImportData, paramsTypeInfo, retTypeInfo, diagnostics, env.Options);
- var code = stubGenerator.GenerateSyntax(forwardedAttributes: forwardedAttributes.Count != 0 ? AttributeList(SeparatedList(forwardedAttributes)) : null);
+ return typeInfos.ToImmutable();
+ }
- var additionalAttrs = new List<AttributeListSyntax>();
+ public override bool Equals(object obj)
+ {
+ return obj is DllImportStubContext other && Equals(other);
+ }
- // Define additional attributes for the stub definition.
- if (env.TargetFrameworkVersion >= new Version(5, 0) && !MethodIsSkipLocalsInit(env, method))
- {
- additionalAttrs.Add(
- AttributeList(
- SeparatedList(new []
- {
- // Adding the skip locals init indiscriminately since the source generator is
- // targeted at non-blittable method signatures which typically will contain locals
- // in the generated code.
- Attribute(ParseName(TypeNames.System_Runtime_CompilerServices_SkipLocalsInitAttribute))
- })));
- }
+ public bool Equals(DllImportStubContext other)
+ {
+ return other is not null
+ && StubTypeNamespace == other.StubTypeNamespace
+ && ElementTypeInformation.SequenceEqual(other.ElementTypeInformation)
+ && StubContainingTypes.SequenceEqual(other.StubContainingTypes, (IEqualityComparer<TypeDeclarationSyntax>)new SyntaxEquivalentComparer())
+ && StubReturnType.IsEquivalentTo(other.StubReturnType)
+ && AdditionalAttributes.SequenceEqual(other.AdditionalAttributes, (IEqualityComparer<AttributeListSyntax>)new SyntaxEquivalentComparer());
+ }
- return new DllImportStub()
- {
- returnTypeInfo = managedRetTypeInfo,
- paramsTypeInfo = paramsTypeInfo,
- StubTypeNamespace = stubTypeNamespace,
- StubContainingTypes = containingTypes,
- StubCode = code,
- AdditionalAttributes = additionalAttrs.ToArray(),
- };
+ public override int GetHashCode()
+ {
+ throw new UnreachableException();
}
private static bool MethodIsSkipLocalsInit(StubEnvironment env, IMethodSymbol method)
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/GeneratedDllImportData.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/GeneratedDllImportData.cs
new file mode 100644
index 00000000000..9a4fc90a125
--- /dev/null
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/GeneratedDllImportData.cs
@@ -0,0 +1,48 @@
+using System;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
+using System.Text;
+
+namespace Microsoft.Interop
+{
+ /// <summary>
+ /// Flags used to indicate members on GeneratedDllImport attribute.
+ /// </summary>
+ [Flags]
+ public enum DllImportMember
+ {
+ None = 0,
+ BestFitMapping = 1 << 0,
+ CallingConvention = 1 << 1,
+ CharSet = 1 << 2,
+ EntryPoint = 1 << 3,
+ ExactSpelling = 1 << 4,
+ PreserveSig = 1 << 5,
+ SetLastError = 1 << 6,
+ ThrowOnUnmappableChar = 1 << 7,
+ All = ~None
+ }
+
+ /// <summary>
+ /// GeneratedDllImportAttribute data
+ /// </summary>
+ /// <remarks>
+ /// The names of these members map directly to those on the
+ /// DllImportAttribute and should not be changed.
+ /// </remarks>
+ public sealed record GeneratedDllImportData(string ModuleName)
+ {
+ /// <summary>
+ /// Value set by the user on the original declaration.
+ /// </summary>
+ public DllImportMember IsUserDefined { get; init; }
+ public bool BestFitMapping { get; init; }
+ public CallingConvention CallingConvention { get; init; }
+ public CharSet CharSet { get; init; }
+ public string? EntryPoint { get; init; }
+ public bool ExactSpelling { get; init; }
+ public bool PreserveSig { get; init; }
+ public bool SetLastError { get; init; }
+ public bool ThrowOnUnmappableChar { get; init; }
+ }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/GeneratorDiagnostics.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/GeneratorDiagnostics.cs
index 51531f6c1a0..74abfb3c714 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/GeneratorDiagnostics.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/GeneratorDiagnostics.cs
@@ -5,6 +5,7 @@ using System.Diagnostics;
using System.Linq;
using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace Microsoft.Interop
{
@@ -15,15 +16,20 @@ namespace Microsoft.Interop
DiagnosticDescriptor descriptor,
params object[] args)
{
- IEnumerable<Location> locationsInSource = symbol.Locations.Where(l => l.IsInSource);
- if (!locationsInSource.Any())
- return Diagnostic.Create(descriptor, Location.None, args);
+ return symbol.Locations.CreateDiagnostic(descriptor, args);
+ }
- return Diagnostic.Create(
- descriptor,
- location: locationsInSource.First(),
- additionalLocations: locationsInSource.Skip(1),
- messageArgs: args);
+ public static Diagnostic CreateDiagnostic(
+ this AttributeData attributeData,
+ DiagnosticDescriptor descriptor,
+ params object[] args)
+ {
+ SyntaxReference? syntaxReference = attributeData.ApplicationSyntaxReference;
+ Location location = syntaxReference is not null
+ ? syntaxReference.GetSyntax().GetLocation()
+ : Location.None;
+
+ return location.CreateDiagnostic(descriptor, args);
}
public static Diagnostic CreateDiagnostic(
@@ -43,15 +49,10 @@ namespace Microsoft.Interop
}
public static Diagnostic CreateDiagnostic(
- this AttributeData attributeData,
+ this Location location,
DiagnosticDescriptor descriptor,
params object[] args)
{
- SyntaxReference? syntaxReference = attributeData.ApplicationSyntaxReference;
- Location location = syntaxReference is not null
- ? syntaxReference.GetSyntax().GetLocation()
- : Location.None;
-
return Diagnostic.Create(
descriptor,
location: location.IsInSource ? location : Location.None,
@@ -174,12 +175,9 @@ namespace Microsoft.Interop
isEnabledByDefault: true,
description: GetResourceString(nameof(Resources.TargetFrameworkNotSupportedDescription)));
- private readonly GeneratorExecutionContext context;
+ private readonly List<Diagnostic> diagnostics = new List<Diagnostic>();
- public GeneratorDiagnostics(GeneratorExecutionContext context)
- {
- this.context = context;
- }
+ public IEnumerable<Diagnostic> Diagnostics => diagnostics;
/// <summary>
/// Report diagnostic for configuration that is not supported by the DLL import source generator
@@ -194,14 +192,14 @@ namespace Microsoft.Interop
{
if (unsupportedValue == null)
{
- this.context.ReportDiagnostic(
+ diagnostics.Add(
attributeData.CreateDiagnostic(
GeneratorDiagnostics.ConfigurationNotSupported,
configurationName));
}
else
{
- this.context.ReportDiagnostic(
+ diagnostics.Add(
attributeData.CreateDiagnostic(
GeneratorDiagnostics.ConfigurationValueNotSupported,
unsupportedValue,
@@ -216,30 +214,44 @@ namespace Microsoft.Interop
/// <param name="info">Type info for the parameter/return</param>
/// <param name="notSupportedDetails">[Optional] Specific reason for lack of support</param>
internal void ReportMarshallingNotSupported(
- IMethodSymbol method,
+ MethodDeclarationSyntax method,
TypePositionInfo info,
string? notSupportedDetails)
{
+ Location diagnosticLocation = Location.None;
+ string elementName = string.Empty;
+
+ if (info.IsManagedReturnPosition)
+ {
+ diagnosticLocation = Location.Create(method.SyntaxTree, method.Identifier.Span);
+ elementName = method.Identifier.ValueText;
+ }
+ else
+ {
+ Debug.Assert(info.ManagedIndex <= method.ParameterList.Parameters.Count);
+ ParameterSyntax param = method.ParameterList.Parameters[info.ManagedIndex];
+ diagnosticLocation = Location.Create(param.SyntaxTree, param.Identifier.Span);
+ elementName = param.Identifier.ValueText;
+ }
+
if (!string.IsNullOrEmpty(notSupportedDetails))
{
// Report the specific not-supported reason.
if (info.IsManagedReturnPosition)
{
- this.context.ReportDiagnostic(
- method.CreateDiagnostic(
+ diagnostics.Add(
+ diagnosticLocation.CreateDiagnostic(
GeneratorDiagnostics.ReturnTypeNotSupportedWithDetails,
notSupportedDetails!,
- method.Name));
+ elementName));
}
else
{
- Debug.Assert(info.ManagedIndex <= method.Parameters.Length);
- IParameterSymbol paramSymbol = method.Parameters[info.ManagedIndex];
- this.context.ReportDiagnostic(
- paramSymbol.CreateDiagnostic(
+ diagnostics.Add(
+ diagnosticLocation.CreateDiagnostic(
GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails,
notSupportedDetails!,
- paramSymbol.Name));
+ elementName));
}
}
else if (info.MarshallingAttributeInfo is MarshalAsInfo)
@@ -249,21 +261,19 @@ namespace Microsoft.Interop
// than when there is no attribute and the type itself is not supported.
if (info.IsManagedReturnPosition)
{
- this.context.ReportDiagnostic(
- method.CreateDiagnostic(
+ diagnostics.Add(
+ diagnosticLocation.CreateDiagnostic(
GeneratorDiagnostics.ReturnConfigurationNotSupported,
nameof(System.Runtime.InteropServices.MarshalAsAttribute),
- method.Name));
+ elementName));
}
else
{
- Debug.Assert(info.ManagedIndex <= method.Parameters.Length);
- IParameterSymbol paramSymbol = method.Parameters[info.ManagedIndex];
- this.context.ReportDiagnostic(
- paramSymbol.CreateDiagnostic(
+ diagnostics.Add(
+ diagnosticLocation.CreateDiagnostic(
GeneratorDiagnostics.ParameterConfigurationNotSupported,
nameof(System.Runtime.InteropServices.MarshalAsAttribute),
- paramSymbol.Name));
+ elementName));
}
}
else
@@ -271,21 +281,19 @@ namespace Microsoft.Interop
// Report that the type is not supported
if (info.IsManagedReturnPosition)
{
- this.context.ReportDiagnostic(
- method.CreateDiagnostic(
+ diagnostics.Add(
+ diagnosticLocation.CreateDiagnostic(
GeneratorDiagnostics.ReturnTypeNotSupported,
- method.ReturnType.ToDisplayString(),
- method.Name));
+ info.ManagedType.DiagnosticFormattedName,
+ elementName));
}
else
{
- Debug.Assert(info.ManagedIndex <= method.Parameters.Length);
- IParameterSymbol paramSymbol = method.Parameters[info.ManagedIndex];
- this.context.ReportDiagnostic(
- paramSymbol.CreateDiagnostic(
+ diagnostics.Add(
+ diagnosticLocation.CreateDiagnostic(
GeneratorDiagnostics.ParameterTypeNotSupported,
- paramSymbol.Type.ToDisplayString(),
- paramSymbol.Name));
+ info.ManagedType.DiagnosticFormattedName,
+ elementName));
}
}
}
@@ -295,7 +303,7 @@ namespace Microsoft.Interop
string reasonResourceName,
params string[] reasonArgs)
{
- this.context.ReportDiagnostic(
+ diagnostics.Add(
attributeData.CreateDiagnostic(
GeneratorDiagnostics.MarshallingAttributeConfigurationNotSupported,
new LocalizableResourceString(reasonResourceName, Resources.ResourceManager, typeof(Resources), reasonArgs)));
@@ -307,7 +315,7 @@ namespace Microsoft.Interop
/// <param name="minimumSupportedVersion">Minimum supported version of .NET</param>
public void ReportTargetFrameworkNotSupported(Version minimumSupportedVersion)
{
- this.context.ReportDiagnostic(
+ diagnostics.Add(
Diagnostic.Create(
TargetFrameworkNotSupported,
Location.None,
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/ManagedTypeInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/ManagedTypeInfo.cs
new file mode 100644
index 00000000000..9e2bd6f70d8
--- /dev/null
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/ManagedTypeInfo.cs
@@ -0,0 +1,74 @@
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Microsoft.Interop
+{
+ /// <summary>
+ /// A discriminated union that contains enough info about a managed type to determine a marshalling generator and generate code.
+ /// </summary>
+ internal abstract record ManagedTypeInfo(string FullTypeName, string DiagnosticFormattedName)
+ {
+ public TypeSyntax Syntax { get; } = SyntaxFactory.ParseTypeName(FullTypeName);
+
+ public static ManagedTypeInfo CreateTypeInfoForTypeSymbol(ITypeSymbol type)
+ {
+ string typeName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
+ string diagonsticFormattedName = type.ToDisplayString();
+ if (type.SpecialType != SpecialType.None)
+ {
+ return new SpecialTypeInfo(typeName, diagonsticFormattedName, type.SpecialType);
+ }
+ if (type.TypeKind == TypeKind.Enum)
+ {
+ return new EnumTypeInfo(typeName, diagonsticFormattedName, ((INamedTypeSymbol)type).EnumUnderlyingType!.SpecialType);
+ }
+ if (type.TypeKind == TypeKind.Pointer)
+ {
+ return new PointerTypeInfo(typeName, diagonsticFormattedName, IsFunctionPointer: false);
+ }
+ if (type.TypeKind == TypeKind.FunctionPointer)
+ {
+ return new PointerTypeInfo(typeName, diagonsticFormattedName, IsFunctionPointer: true);
+ }
+ if (type.TypeKind == TypeKind.Array && type is IArrayTypeSymbol { IsSZArray: true } arraySymbol)
+ {
+ return new SzArrayType(CreateTypeInfoForTypeSymbol(arraySymbol.ElementType));
+ }
+ if (type.TypeKind == TypeKind.Delegate)
+ {
+ return new DelegateTypeInfo(typeName, diagonsticFormattedName);
+ }
+ return new SimpleManagedTypeInfo(typeName, diagonsticFormattedName);
+ }
+ }
+
+ internal sealed record SpecialTypeInfo(string FullTypeName, string DiagnosticFormattedName, SpecialType SpecialType) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName)
+ {
+ public static readonly SpecialTypeInfo Int32 = new("int", "int", SpecialType.System_Int32);
+ public static readonly SpecialTypeInfo Void = new("void", "void", SpecialType.System_Void);
+
+ public bool Equals(SpecialTypeInfo? other)
+ {
+ return other is not null && SpecialType == other.SpecialType;
+ }
+
+ public override int GetHashCode()
+ {
+ return (int)SpecialType;
+ }
+ }
+
+ internal sealed record EnumTypeInfo(string FullTypeName, string DiagnosticFormattedName, SpecialType UnderlyingType) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName);
+
+ internal sealed record PointerTypeInfo(string FullTypeName, string DiagnosticFormattedName, bool IsFunctionPointer) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName);
+
+ internal sealed record SzArrayType(ManagedTypeInfo ElementTypeInfo) : ManagedTypeInfo($"{ElementTypeInfo.FullTypeName}[]", $"{ElementTypeInfo.DiagnosticFormattedName}[]");
+
+ internal sealed record DelegateTypeInfo(string FullTypeName, string DiagnosticFormattedName) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName);
+
+ internal sealed record SimpleManagedTypeInfo(string FullTypeName, string DiagnosticFormattedName) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName);
+}
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BlittableMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BlittableMarshaller.cs
index 0f158706cd0..15fa7774a99 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BlittableMarshaller.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BlittableMarshaller.cs
@@ -11,7 +11,7 @@ namespace Microsoft.Interop
{
public TypeSyntax AsNativeType(TypePositionInfo info)
{
- return info.ManagedType.AsTypeSyntax();
+ return info.ManagedType.Syntax;
}
public ParameterSyntax AsParameter(TypePositionInfo info)
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BoolMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BoolMarshaller.cs
index b659ccbda4b..2b5af8d2a5c 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BoolMarshaller.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BoolMarshaller.cs
@@ -25,7 +25,7 @@ namespace Microsoft.Interop
public TypeSyntax AsNativeType(TypePositionInfo info)
{
- Debug.Assert(info.ManagedType.SpecialType == SpecialType.System_Boolean);
+ Debug.Assert(info.ManagedType is SpecialTypeInfo(_, _, SpecialType.System_Boolean));
return _nativeType;
}
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/CharMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/CharMarshaller.cs
index b3279390da7..d04a78c2460 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/CharMarshaller.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/CharMarshaller.cs
@@ -33,7 +33,7 @@ namespace Microsoft.Interop
public TypeSyntax AsNativeType(TypePositionInfo info)
{
- Debug.Assert(info.ManagedType.SpecialType == SpecialType.System_Char);
+ Debug.Assert(info.ManagedType is SpecialTypeInfo(_, _, SpecialType.System_Char));
return NativeType;
}
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/DelegateMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/DelegateMarshaller.cs
index 21a5c80d4cc..b5aca0f4420 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/DelegateMarshaller.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/DelegateMarshaller.cs
@@ -89,7 +89,7 @@ namespace Microsoft.Interop
.WithTypeArgumentList(
TypeArgumentList(
SingletonSeparatedList(
- info.ManagedType.AsTypeSyntax())))),
+ info.ManagedType.Syntax)))),
ArgumentList(SingletonSeparatedList(Argument(IdentifierName(nativeIdentifier))))),
LiteralExpression(SyntaxKind.NullLiteralExpression))));
}
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/Forwarder.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/Forwarder.cs
index 1bd26150668..e86c422ec48 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/Forwarder.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/Forwarder.cs
@@ -12,7 +12,7 @@ namespace Microsoft.Interop
{
public TypeSyntax AsNativeType(TypePositionInfo info)
{
- return info.ManagedType.AsTypeSyntax();
+ return info.ManagedType.Syntax;
}
private bool TryRehydrateMarshalAsAttribute(TypePositionInfo info, out AttributeSyntax marshalAsAttribute)
@@ -87,7 +87,7 @@ namespace Microsoft.Interop
{
ParameterSyntax param = Parameter(Identifier(info.InstanceIdentifier))
.WithModifiers(TokenList(Token(info.RefKindSyntax)))
- .WithType(info.ManagedType.AsTypeSyntax());
+ .WithType(info.ManagedType.Syntax);
if (TryRehydrateMarshalAsAttribute(info, out AttributeSyntax marshalAsAttribute))
{
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/HResultExceptionMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/HResultExceptionMarshaller.cs
index 4b56c8a3107..09ce81431e6 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/HResultExceptionMarshaller.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/HResultExceptionMarshaller.cs
@@ -15,7 +15,7 @@ namespace Microsoft.Interop
public TypeSyntax AsNativeType(TypePositionInfo info)
{
- Debug.Assert(info.ManagedType.SpecialType == SpecialType.System_Int32);
+ Debug.Assert(info.ManagedType is SpecialTypeInfo(_, _, SpecialType.System_Int32));
return NativeType;
}
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/MarshallingGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/MarshallingGenerator.cs
index 946e96fe562..24fb56ad90b 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/MarshallingGenerator.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/MarshallingGenerator.cs
@@ -194,31 +194,31 @@ namespace Microsoft.Interop
if (info.IsNativeReturnPosition && !info.IsManagedReturnPosition)
{
// Use marshaller for native HRESULT return / exception throwing
- System.Diagnostics.Debug.Assert(info.ManagedType.SpecialType == SpecialType.System_Int32);
+ System.Diagnostics.Debug.Assert(info.ManagedType is SpecialTypeInfo { SpecialType: SpecialType.System_Int32 });
return HResultException;
}
switch (info)
{
// Blittable primitives with no marshalling info or with a compatible [MarshalAs] attribute.
- case { ManagedType: { SpecialType: SpecialType.System_SByte }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.I1, _) }
- or { ManagedType: { SpecialType: SpecialType.System_Byte }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.U1, _) }
- or { ManagedType: { SpecialType: SpecialType.System_Int16 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.I2, _) }
- or { ManagedType: { SpecialType: SpecialType.System_UInt16 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.U2, _) }
- or { ManagedType: { SpecialType: SpecialType.System_Int32 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.I4, _) }
- or { ManagedType: { SpecialType: SpecialType.System_UInt32 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.U4, _) }
- or { ManagedType: { SpecialType: SpecialType.System_Int64 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.I8, _) }
- or { ManagedType: { SpecialType: SpecialType.System_UInt64 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.U8, _) }
- or { ManagedType: { SpecialType: SpecialType.System_IntPtr }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.SysInt, _) }
- or { ManagedType: { SpecialType: SpecialType.System_UIntPtr }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.SysUInt, _) }
- or { ManagedType: { SpecialType: SpecialType.System_Single }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.R4, _) }
- or { ManagedType: { SpecialType: SpecialType.System_Double }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.R8, _) }:
+ case { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_SByte }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.I1, _) }
+ or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Byte }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.U1, _) }
+ or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Int16 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.I2, _) }
+ or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_UInt16 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.U2, _) }
+ or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Int32 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.I4, _) }
+ or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_UInt32 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.U4, _) }
+ or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Int64 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.I8, _) }
+ or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_UInt64 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.U8, _) }
+ or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_IntPtr }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.SysInt, _) }
+ or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_UIntPtr }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.SysUInt, _) }
+ or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Single }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.R4, _) }
+ or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Double }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.R8, _) }:
return Blittable;
// Enum with no marshalling info
- case { ManagedType: { TypeKind: TypeKind.Enum }, MarshallingAttributeInfo: NoMarshallingInfo }:
+ case { ManagedType: EnumTypeInfo enumType, MarshallingAttributeInfo: NoMarshallingInfo }:
// Check that the underlying type is not bool or char. C# does not allow this, but ECMA-335 does.
- var underlyingSpecialType = ((INamedTypeSymbol)info.ManagedType).EnumUnderlyingType!.SpecialType;
+ var underlyingSpecialType = enumType.UnderlyingType;
if (underlyingSpecialType == SpecialType.System_Boolean || underlyingSpecialType == SpecialType.System_Char)
{
throw new MarshallingNotSupportedException(info, context);
@@ -226,31 +226,31 @@ namespace Microsoft.Interop
return Blittable;
// Pointer with no marshalling info
- case { ManagedType: { TypeKind: TypeKind.Pointer }, MarshallingAttributeInfo: NoMarshallingInfo }:
+ case { ManagedType: PointerTypeInfo(_, _, IsFunctionPointer:false), MarshallingAttributeInfo: NoMarshallingInfo }:
return Blittable;
// Function pointer with no marshalling info
- case { ManagedType: { TypeKind: TypeKind.FunctionPointer }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.FunctionPtr, _) }:
+ case { ManagedType: PointerTypeInfo(_, _, IsFunctionPointer: true), MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.FunctionPtr, _) }:
return Blittable;
- case { ManagedType: { SpecialType: SpecialType.System_Boolean }, MarshallingAttributeInfo: NoMarshallingInfo }:
+ case { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Boolean }, MarshallingAttributeInfo: NoMarshallingInfo }:
return WinBool; // [Compat] Matching the default for the built-in runtime marshallers.
- case { ManagedType: { SpecialType: SpecialType.System_Boolean }, MarshallingAttributeInfo: MarshalAsInfo(UnmanagedType.I1 or UnmanagedType.U1, _) }:
+ case { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Boolean }, MarshallingAttributeInfo: MarshalAsInfo(UnmanagedType.I1 or UnmanagedType.U1, _) }:
return ByteBool;
- case { ManagedType: { SpecialType: SpecialType.System_Boolean }, MarshallingAttributeInfo: MarshalAsInfo(UnmanagedType.I4 or UnmanagedType.U4 or UnmanagedType.Bool, _) }:
+ case { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Boolean }, MarshallingAttributeInfo: MarshalAsInfo(UnmanagedType.I4 or UnmanagedType.U4 or UnmanagedType.Bool, _) }:
return WinBool;
- case { ManagedType: { SpecialType: SpecialType.System_Boolean }, MarshallingAttributeInfo: MarshalAsInfo(UnmanagedType.VariantBool, _) }:
+ case { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Boolean }, MarshallingAttributeInfo: MarshalAsInfo(UnmanagedType.VariantBool, _) }:
return VariantBool;
- case { ManagedType: { TypeKind: TypeKind.Delegate }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.FunctionPtr, _) }:
+ case { ManagedType: DelegateTypeInfo, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.FunctionPtr, _) }:
return Delegate;
- case { MarshallingAttributeInfo: SafeHandleMarshallingInfo }:
+ case { MarshallingAttributeInfo: SafeHandleMarshallingInfo(_, bool isAbstract) }:
if (!context.AdditionalTemporaryStateLivesAcrossStages)
{
throw new MarshallingNotSupportedException(info, context);
}
- if (info.IsByRef && info.ManagedType.IsAbstract)
+ if (info.IsByRef && isAbstract)
{
throw new MarshallingNotSupportedException(info, context)
{
@@ -274,13 +274,13 @@ namespace Microsoft.Interop
// Cases that just match on type must come after the checks that match only on marshalling attribute info.
// The checks below do not account for generic marshalling overrides like [MarshalUsing], so those checks must come first.
- case { ManagedType: { SpecialType: SpecialType.System_Char } }:
+ case { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Char } }:
return CreateCharMarshaller(info, context);
- case { ManagedType: { SpecialType: SpecialType.System_String } }:
+ case { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_String } }:
return CreateStringMarshaller(info, context);
- case { ManagedType: { SpecialType: SpecialType.System_Void } }:
+ case { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Void } }:
return Forwarder;
default:
@@ -403,7 +403,7 @@ namespace Microsoft.Interop
paramInfo,
out int numIndirectionLevels);
- ITypeSymbol type = paramInfo.ManagedType;
+ ManagedTypeInfo type = paramInfo.ManagedType;
MarshallingInfo marshallingInfo = paramInfo.MarshallingAttributeInfo;
for (int i = 0; i < numIndirectionLevels; i++)
@@ -422,7 +422,7 @@ namespace Microsoft.Interop
}
}
- if (!type.IsIntegralType())
+ if (type is not SpecialTypeInfo specialType || !specialType.SpecialType.IsIntegralType())
{
throw new MarshallingNotSupportedException(info, context)
{
@@ -470,14 +470,14 @@ namespace Microsoft.Interop
{
ValidateCustomNativeTypeMarshallingSupported(info, context, marshalInfo);
- ICustomNativeTypeMarshallingStrategy marshallingStrategy = new SimpleCustomNativeTypeMarshalling(marshalInfo.NativeMarshallingType.AsTypeSyntax());
+ ICustomNativeTypeMarshallingStrategy marshallingStrategy = new SimpleCustomNativeTypeMarshalling(marshalInfo.NativeMarshallingType.Syntax);
- if ((marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNativeStackalloc) != 0)
+ if ((marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.ManagedToNativeStackalloc) != 0)
{
marshallingStrategy = new StackallocOptimizationMarshalling(marshallingStrategy);
}
- if (ManualTypeMarshallingHelper.HasFreeNativeMethod(marshalInfo.NativeMarshallingType))
+ if ((marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.FreeNativeResources) != 0)
{
marshallingStrategy = new FreeNativeCleanupStrategy(marshallingStrategy);
}
@@ -495,7 +495,7 @@ namespace Microsoft.Interop
IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false);
- if ((marshalInfo.MarshallingMethods & SupportedMarshallingMethods.Pinning) != 0)
+ if ((marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.ManagedTypePinning) != 0)
{
return new PinnableManagedValueMarshaller(marshallingGenerator);
}
@@ -508,53 +508,54 @@ namespace Microsoft.Interop
// The marshalling method for this type doesn't support marshalling from native to managed,
// but our scenario requires marshalling from native to managed.
if ((info.RefKind == RefKind.Ref || info.RefKind == RefKind.Out || info.IsManagedReturnPosition)
- && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.NativeToManaged) == 0)
+ && (marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.NativeToManaged) == 0)
{
throw new MarshallingNotSupportedException(info, context)
{
- NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingNativeToManagedUnsupported, marshalInfo.NativeMarshallingType.ToDisplayString())
+ NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingNativeToManagedUnsupported, marshalInfo.NativeMarshallingType.FullTypeName)
};
}
// The marshalling method for this type doesn't support marshalling from managed to native by value,
// but our scenario requires marshalling from managed to native by value.
else if (!info.IsByRef
- && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNative) == 0
- && (context.SingleFrameSpansNativeContext && (marshalInfo.MarshallingMethods & (SupportedMarshallingMethods.Pinning | SupportedMarshallingMethods.ManagedToNativeStackalloc)) == 0))
+ && (marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.ManagedToNative) == 0
+ && (context.SingleFrameSpansNativeContext && (marshalInfo.MarshallingFeatures & (CustomMarshallingFeatures.ManagedTypePinning | CustomMarshallingFeatures.ManagedToNativeStackalloc)) == 0))
{
throw new MarshallingNotSupportedException(info, context)
{
- NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingManagedToNativeUnsupported, marshalInfo.NativeMarshallingType.ToDisplayString())
+ NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingManagedToNativeUnsupported, marshalInfo.NativeMarshallingType.FullTypeName)
};
}
// The marshalling method for this type doesn't support marshalling from managed to native by reference,
// but our scenario requires marshalling from managed to native by reference.
// "in" byref supports stack marshalling.
else if (info.RefKind == RefKind.In
- && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNative) == 0
- && !(context.SingleFrameSpansNativeContext && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNativeStackalloc) != 0))
+ && (marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.ManagedToNative) == 0
+ && !(context.SingleFrameSpansNativeContext && (marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.ManagedToNativeStackalloc) != 0))
{
throw new MarshallingNotSupportedException(info, context)
{
- NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingManagedToNativeUnsupported, marshalInfo.NativeMarshallingType.ToDisplayString())
+ NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingManagedToNativeUnsupported, marshalInfo.NativeMarshallingType.FullTypeName)
};
}
// The marshalling method for this type doesn't support marshalling from managed to native by reference,
// but our scenario requires marshalling from managed to native by reference.
// "ref" byref marshalling doesn't support stack marshalling
else if (info.RefKind == RefKind.Ref
- && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNative) == 0)
+ && (marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.ManagedToNative) == 0)
{
throw new MarshallingNotSupportedException(info, context)
{
- NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingManagedToNativeUnsupported, marshalInfo.NativeMarshallingType.ToDisplayString())
+ NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingManagedToNativeUnsupported, marshalInfo.NativeMarshallingType.FullTypeName)
};
}
}
private static ICustomNativeTypeMarshallingStrategy DecorateWithValuePropertyStrategy(NativeMarshallingAttributeInfo marshalInfo, ICustomNativeTypeMarshallingStrategy nativeTypeMarshaller)
{
- TypeSyntax valuePropertyTypeSyntax = marshalInfo.ValuePropertyType!.AsTypeSyntax();
- if (ManualTypeMarshallingHelper.FindGetPinnableReference(marshalInfo.NativeMarshallingType) is not null)
+ TypeSyntax valuePropertyTypeSyntax = marshalInfo.ValuePropertyType!.Syntax;
+
+ if ((marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.NativeTypePinning) != 0)
{
return new PinnableMarshallerTypeMarshalling(nativeTypeMarshaller, valuePropertyTypeSyntax);
}
@@ -569,7 +570,7 @@ namespace Microsoft.Interop
AnalyzerConfigOptions options,
ICustomNativeTypeMarshallingStrategy marshallingStrategy)
{
- var elementInfo = TypePositionInfo.CreateForType(collectionInfo.ElementType, collectionInfo.ElementMarshallingInfo) with { ManagedIndex = info.ManagedIndex };
+ var elementInfo = new TypePositionInfo(collectionInfo.ElementType, collectionInfo.ElementMarshallingInfo) { ManagedIndex = info.ManagedIndex };
var elementMarshaller = Create(
elementInfo,
new ContiguousCollectionElementMarshallingCodeContext(StubCodeContext.Stage.Setup, string.Empty, context),
@@ -580,7 +581,7 @@ namespace Microsoft.Interop
if (isBlittable)
{
- marshallingStrategy = new ContiguousBlittableElementCollectionMarshalling(marshallingStrategy, collectionInfo.ElementType.AsTypeSyntax());
+ marshallingStrategy = new ContiguousBlittableElementCollectionMarshalling(marshallingStrategy, collectionInfo.ElementType.Syntax);
}
else
{
@@ -605,7 +606,7 @@ namespace Microsoft.Interop
numElementsExpression,
SizeOfExpression(elementType));
- if (collectionInfo.UseDefaultMarshalling && info.ManagedType is IArrayTypeSymbol { IsSZArray: true })
+ if (collectionInfo.UseDefaultMarshalling && info.ManagedType is SzArrayType)
{
return new ArrayMarshaller(
new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: true),
@@ -616,7 +617,7 @@ namespace Microsoft.Interop
IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false);
- if ((collectionInfo.MarshallingMethods & SupportedMarshallingMethods.Pinning) != 0)
+ if ((collectionInfo.MarshallingFeatures & CustomMarshallingFeatures.ManagedTypePinning) != 0)
{
return new PinnableManagedValueMarshaller(marshallingGenerator);
}
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs
index 441d2ce758b..e3b32f94a5c 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs
@@ -83,9 +83,9 @@ namespace Microsoft.Interop
}
var safeHandleCreationExpression = ((SafeHandleMarshallingInfo)info.MarshallingAttributeInfo).AccessibleDefaultConstructor
- ? (ExpressionSyntax)ObjectCreationExpression(info.ManagedType.AsTypeSyntax(), ArgumentList(), initializer: null)
+ ? (ExpressionSyntax)ObjectCreationExpression(info.ManagedType.Syntax, ArgumentList(), initializer: null)
: CastExpression(
- info.ManagedType.AsTypeSyntax(),
+ info.ManagedType.Syntax,
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
@@ -97,7 +97,7 @@ namespace Microsoft.Interop
new []{
Argument(
TypeOfExpression(
- info.ManagedType.AsTypeSyntax())),
+ info.ManagedType.Syntax)),
Argument(
LiteralExpression(
SyntaxKind.TrueLiteralExpression))
@@ -121,7 +121,7 @@ namespace Microsoft.Interop
// leak the handle if we failed to create the handle.
yield return LocalDeclarationStatement(
VariableDeclaration(
- info.ManagedType.AsTypeSyntax(),
+ info.ManagedType.Syntax,
SingletonSeparatedList(
VariableDeclarator(newHandleObjectIdentifier)
.WithInitializer(EqualsValueClause(safeHandleCreationExpression)))));
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/MarshallingAttributeInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/MarshallingAttributeInfo.cs
index d69f2ea1a58..96988216005 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/MarshallingAttributeInfo.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/MarshallingAttributeInfo.cs
@@ -67,14 +67,15 @@ namespace Microsoft.Interop
internal sealed record BlittableTypeAttributeInfo : MarshallingInfo;
[Flags]
- internal enum SupportedMarshallingMethods
+ internal enum CustomMarshallingFeatures
{
None = 0,
ManagedToNative = 0x1,
NativeToManaged = 0x2,
ManagedToNativeStackalloc = 0x4,
- Pinning = 0x8,
- All = -1
+ ManagedTypePinning = 0x8,
+ NativeTypePinning = 0x10,
+ FreeNativeResources = 0x20,
}
internal abstract record CountInfo;
@@ -106,10 +107,9 @@ namespace Microsoft.Interop
/// User-applied System.Runtime.InteropServices.NativeMarshallingAttribute
/// </summary>
internal record NativeMarshallingAttributeInfo(
- ITypeSymbol NativeMarshallingType,
- ITypeSymbol? ValuePropertyType,
- SupportedMarshallingMethods MarshallingMethods,
- bool NativeTypePinnable,
+ ManagedTypeInfo NativeMarshallingType,
+ ManagedTypeInfo? ValuePropertyType,
+ CustomMarshallingFeatures MarshallingFeatures,
bool UseDefaultMarshalling) : MarshallingInfo;
/// <summary>
@@ -122,24 +122,22 @@ namespace Microsoft.Interop
/// <summary>
/// The type of the element is a SafeHandle-derived type with no marshalling attributes.
/// </summary>
- internal sealed record SafeHandleMarshallingInfo(bool AccessibleDefaultConstructor) : MarshallingInfo;
+ internal sealed record SafeHandleMarshallingInfo(bool AccessibleDefaultConstructor, bool IsAbstract) : MarshallingInfo;
/// <summary>
/// User-applied System.Runtime.InteropServices.NativeMarshallingAttribute
/// with a contiguous collection marshaller
internal sealed record NativeContiguousCollectionMarshallingInfo(
- ITypeSymbol NativeMarshallingType,
- ITypeSymbol? ValuePropertyType,
- SupportedMarshallingMethods MarshallingMethods,
- bool NativeTypePinnable,
+ ManagedTypeInfo NativeMarshallingType,
+ ManagedTypeInfo? ValuePropertyType,
+ CustomMarshallingFeatures MarshallingFeatures,
bool UseDefaultMarshalling,
CountInfo ElementCountInfo,
- ITypeSymbol ElementType,
+ ManagedTypeInfo ElementType,
MarshallingInfo ElementMarshallingInfo) : NativeMarshallingAttributeInfo(
NativeMarshallingType,
ValuePropertyType,
- MarshallingMethods,
- NativeTypePinnable,
+ MarshallingFeatures,
UseDefaultMarshalling
);
@@ -407,8 +405,8 @@ namespace Microsoft.Interop
{
if (elementName == CountElementCountInfo.ReturnValueElementName)
{
- return TypePositionInfo.CreateForType(
- method.ReturnType,
+ return new TypePositionInfo(
+ ManagedTypeInfo.CreateTypeInfoForTypeSymbol(method.ReturnType),
ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes(), inspectedElements)) with
{
ManagedIndex = TypePositionInfo.ReturnIndex
@@ -539,14 +537,15 @@ namespace Microsoft.Interop
return NoMarshallingInfo.Instance;
}
+ ITypeSymbol? valuePropertyType = ManualTypeMarshallingHelper.FindValueProperty(arrayMarshaller)?.Type;
+
return new NativeContiguousCollectionMarshallingInfo(
- NativeMarshallingType: arrayMarshaller,
- ValuePropertyType: ManualTypeMarshallingHelper.FindValueProperty(arrayMarshaller)?.Type,
- MarshallingMethods: ~SupportedMarshallingMethods.Pinning,
- NativeTypePinnable: true,
+ NativeMarshallingType: ManagedTypeInfo.CreateTypeInfoForTypeSymbol(arrayMarshaller),
+ ValuePropertyType: valuePropertyType is not null ? ManagedTypeInfo.CreateTypeInfoForTypeSymbol(valuePropertyType) : null,
+ MarshallingFeatures: ~CustomMarshallingFeatures.ManagedTypePinning,
UseDefaultMarshalling: true,
ElementCountInfo: arraySizeInfo,
- ElementType: elementType,
+ ElementType: ManagedTypeInfo.CreateTypeInfoForTypeSymbol(elementType),
ElementMarshallingInfo: elementMarshallingInfo);
}
@@ -560,11 +559,11 @@ namespace Microsoft.Interop
ImmutableHashSet<string> inspectedElements,
ref int maxIndirectionLevelUsed)
{
- SupportedMarshallingMethods methods = SupportedMarshallingMethods.None;
+ CustomMarshallingFeatures features = CustomMarshallingFeatures.None;
if (!isMarshalUsingAttribute && ManualTypeMarshallingHelper.FindGetPinnableReference(type) is not null)
{
- methods |= SupportedMarshallingMethods.Pinning;
+ features |= CustomMarshallingFeatures.ManagedTypePinning;
}
ITypeSymbol spanOfByte = _compilation.GetTypeByMetadataName(TypeNames.System_Span_Metadata)!.Construct(_compilation.GetSpecialType(SpecialType.System_Byte));
@@ -611,12 +610,12 @@ namespace Microsoft.Interop
{
if (ManualTypeMarshallingHelper.IsManagedToNativeConstructor(ctor, type, marshallingVariant) && (valueProperty is null or { GetMethod: not null }))
{
- methods |= SupportedMarshallingMethods.ManagedToNative;
+ features |= CustomMarshallingFeatures.ManagedToNative;
}
else if (ManualTypeMarshallingHelper.IsStackallocConstructor(ctor, type, spanOfByte, marshallingVariant)
&& (valueProperty is null or { GetMethod: not null }))
{
- methods |= SupportedMarshallingMethods.ManagedToNativeStackalloc;
+ features |= CustomMarshallingFeatures.ManagedToNativeStackalloc;
}
else if (ctor.Parameters.Length == 1 && ctor.Parameters[0].Type.SpecialType == SpecialType.System_Int32)
{
@@ -631,10 +630,10 @@ namespace Microsoft.Interop
&& ManualTypeMarshallingHelper.HasToManagedMethod(nativeType, type)
&& (valueProperty is null or { SetMethod: not null }))
{
- methods |= SupportedMarshallingMethods.NativeToManaged;
+ features |= CustomMarshallingFeatures.NativeToManaged;
}
- if (methods == SupportedMarshallingMethods.None)
+ if (features == CustomMarshallingFeatures.None)
{
_diagnostics.ReportInvalidMarshallingAttributeInfo(
attrData,
@@ -645,6 +644,16 @@ namespace Microsoft.Interop
return NoMarshallingInfo.Instance;
}
+ if (ManualTypeMarshallingHelper.HasFreeNativeMethod(nativeType))
+ {
+ features |= CustomMarshallingFeatures.FreeNativeResources;
+ }
+
+ if (ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null)
+ {
+ features |= CustomMarshallingFeatures.NativeTypePinning;
+ }
+
if (isContiguousCollectionMarshaller)
{
if (!ManualTypeMarshallingHelper.HasNativeValueStorageProperty(nativeType, spanOfByte))
@@ -660,21 +669,19 @@ namespace Microsoft.Interop
}
return new NativeContiguousCollectionMarshallingInfo(
- nativeType,
- valueProperty?.Type,
- methods,
- NativeTypePinnable: ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null,
+ ManagedTypeInfo.CreateTypeInfoForTypeSymbol(nativeType),
+ valueProperty is not null ? ManagedTypeInfo.CreateTypeInfoForTypeSymbol(valueProperty.Type) : null,
+ features,
UseDefaultMarshalling: !isMarshalUsingAttribute,
parsedCountInfo,
- elementType,
+ ManagedTypeInfo.CreateTypeInfoForTypeSymbol(elementType),
GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1, inspectedElements, ref maxIndirectionLevelUsed));
}
return new NativeMarshallingAttributeInfo(
- nativeType,
- valueProperty?.Type,
- methods,
- NativeTypePinnable: ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null,
+ ManagedTypeInfo.CreateTypeInfoForTypeSymbol(nativeType),
+ valueProperty is not null ? ManagedTypeInfo.CreateTypeInfoForTypeSymbol(valueProperty.Type) : null,
+ features,
UseDefaultMarshalling: !isMarshalUsingAttribute);
}
@@ -705,7 +712,7 @@ namespace Microsoft.Interop
}
}
}
- marshallingInfo = new SafeHandleMarshallingInfo(hasAccessibleDefaultConstructor);
+ marshallingInfo = new SafeHandleMarshallingInfo(hasAccessibleDefaultConstructor, type.IsAbstract);
return true;
}
@@ -729,14 +736,15 @@ namespace Microsoft.Interop
return false;
}
+ ITypeSymbol? valuePropertyType = ManualTypeMarshallingHelper.FindValueProperty(arrayMarshaller)?.Type;
+
marshallingInfo = new NativeContiguousCollectionMarshallingInfo(
- NativeMarshallingType: arrayMarshaller,
- ValuePropertyType: ManualTypeMarshallingHelper.FindValueProperty(arrayMarshaller)?.Type,
- MarshallingMethods: ~SupportedMarshallingMethods.Pinning,
- NativeTypePinnable: true,
+ NativeMarshallingType: ManagedTypeInfo.CreateTypeInfoForTypeSymbol(arrayMarshaller),
+ ValuePropertyType: valuePropertyType is not null ? ManagedTypeInfo.CreateTypeInfoForTypeSymbol(valuePropertyType) : null,
+ MarshallingFeatures: ~CustomMarshallingFeatures.ManagedTypePinning,
UseDefaultMarshalling: true,
ElementCountInfo: parsedCountInfo,
- ElementType: elementType,
+ ElementType: ManagedTypeInfo.CreateTypeInfoForTypeSymbol(elementType),
ElementMarshallingInfo: GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1, inspectedElements, ref maxIndirectionLevelUsed));
return true;
}
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeContext.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeContext.cs
index fe488c60881..1fcacdb07a7 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeContext.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeContext.cs
@@ -61,7 +61,7 @@ namespace Microsoft.Interop
GuaranteedUnmarshal
}
- public Stage CurrentStage { get; protected set; } = Stage.Invalid;
+ public Stage CurrentStage { get; set; } = Stage.Invalid;
/// <summary>
/// The stub emits code that runs in a single stack frame and the frame spans over the native context.
@@ -88,7 +88,7 @@ namespace Microsoft.Interop
/// </summary>
public StubCodeContext? ParentContext { get; protected set; }
- protected const string GeneratedNativeIdentifierSuffix = "_gen_native";
+ public const string GeneratedNativeIdentifierSuffix = "_gen_native";
/// <summary>
/// Get managed and native instance identifiers for the <paramref name="info"/>
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeGenerator.cs
index 5679a023509..2732414adc4 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeGenerator.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeGenerator.cs
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
+using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Runtime.InteropServices;
@@ -8,11 +9,14 @@ using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
+using static Microsoft.Interop.StubCodeContext;
namespace Microsoft.Interop
{
internal sealed class StubCodeGenerator : StubCodeContext
{
+ private record struct BoundGenerator(TypePositionInfo TypeInfo, IMarshallingGenerator Generator);
+
public override bool SingleFrameSpansNativeContext => true;
public override bool AdditionalTemporaryStateLivesAcrossStages => true;
@@ -26,7 +30,7 @@ namespace Microsoft.Interop
/// Identifier for native return value
/// </summary>
/// <remarks>Same as the managed identifier by default</remarks>
- public string ReturnNativeIdentifier { get; private set; } = ReturnIdentifier;
+ public string ReturnNativeIdentifier { get; } = ReturnIdentifier;
private const string InvokeReturnIdentifier = "__invokeRetVal";
private const string LastErrorIdentifier = "__lastError";
@@ -35,40 +39,62 @@ namespace Microsoft.Interop
// Error code representing success. This maps to S_OK for Windows HRESULT semantics and 0 for POSIX errno semantics.
private const int SuccessErrorCode = 0;
- private readonly GeneratorDiagnostics diagnostics;
private readonly AnalyzerConfigOptions options;
- private readonly IMethodSymbol stubMethod;
- private readonly DllImportStub.GeneratedDllImportData dllImportData;
- private readonly IEnumerable<TypePositionInfo> paramsTypeInfo;
- private readonly List<(TypePositionInfo TypeInfo, IMarshallingGenerator Generator)> paramMarshallers;
- private readonly (TypePositionInfo TypeInfo, IMarshallingGenerator Generator) retMarshaller;
- private readonly List<(TypePositionInfo TypeInfo, IMarshallingGenerator Generator)> sortedMarshallers;
+ private readonly GeneratedDllImportData dllImportData;
+ private readonly List<BoundGenerator> paramMarshallers;
+ private readonly BoundGenerator retMarshaller;
+ private readonly List<BoundGenerator> sortedMarshallers;
+ private readonly bool stubReturnsVoid;
public StubCodeGenerator(
- IMethodSymbol stubMethod,
- DllImportStub.GeneratedDllImportData dllImportData,
- IEnumerable<TypePositionInfo> paramsTypeInfo,
- TypePositionInfo retTypeInfo,
- GeneratorDiagnostics generatorDiagnostics,
- AnalyzerConfigOptions options)
+ GeneratedDllImportData dllImportData,
+ IEnumerable<TypePositionInfo> argTypes,
+ AnalyzerConfigOptions options,
+ Action<TypePositionInfo, MarshallingNotSupportedException> marshallingNotSupportedCallback)
{
- Debug.Assert(retTypeInfo.IsNativeReturnPosition);
-
- this.stubMethod = stubMethod;
this.dllImportData = dllImportData;
- this.paramsTypeInfo = paramsTypeInfo.ToList();
- this.diagnostics = generatorDiagnostics;
this.options = options;
- // Get marshallers for parameters
- this.paramMarshallers = paramsTypeInfo.Select(p => CreateGenerator(p)).ToList();
+ List<BoundGenerator> allMarshallers = new();
+ List<BoundGenerator> paramMarshallers = new();
+ bool foundNativeRetMarshaller = false;
+ bool foundManagedRetMarshaller = false;
+ BoundGenerator nativeRetMarshaller = new(new TypePositionInfo(SpecialTypeInfo.Void, NoMarshallingInfo.Instance), new Forwarder());
+ BoundGenerator managedRetMarshaller = new(new TypePositionInfo(SpecialTypeInfo.Void, NoMarshallingInfo.Instance), new Forwarder());
+
+ foreach (var argType in argTypes)
+ {
+ BoundGenerator generator = CreateGenerator(argType);
+ allMarshallers.Add(generator);
+ if (argType.IsManagedReturnPosition)
+ {
+ Debug.Assert(!foundManagedRetMarshaller);
+ managedRetMarshaller = generator;
+ foundManagedRetMarshaller = true;
+ }
+ if (argType.IsNativeReturnPosition)
+ {
+ Debug.Assert(!foundNativeRetMarshaller);
+ nativeRetMarshaller = generator;
+ foundNativeRetMarshaller = true;
+ }
+ if (!argType.IsManagedReturnPosition && !argType.IsNativeReturnPosition)
+ {
+ paramMarshallers.Add(generator);
+ }
+ }
- // Get marshaller for return
- this.retMarshaller = CreateGenerator(retTypeInfo);
+ this.stubReturnsVoid = managedRetMarshaller.TypeInfo.ManagedType == SpecialTypeInfo.Void;
+ if (!managedRetMarshaller.TypeInfo.IsNativeReturnPosition && !this.stubReturnsVoid)
+ {
+ // If the managed ret marshaller isn't the native ret marshaller, then the managed ret marshaller
+ // is a parameter.
+ paramMarshallers.Add(managedRetMarshaller);
+ }
- List<(TypePositionInfo TypeInfo, IMarshallingGenerator Generator)> allMarshallers = new(this.paramMarshallers);
- allMarshallers.Add(retMarshaller);
+ this.retMarshaller = nativeRetMarshaller;
+ this.paramMarshallers = paramMarshallers;
// We are doing a topological sort of our marshallers to ensure that each parameter/return value's
// dependencies are unmarshalled before their dependents. This comes up in the case of contiguous
@@ -98,17 +124,10 @@ namespace Microsoft.Interop
static m => GetInfoDependencies(m.TypeInfo))
.ToList();
- (TypePositionInfo info, IMarshallingGenerator gen) CreateGenerator(TypePositionInfo p)
+ if (managedRetMarshaller.Generator.UsesNativeIdentifier(managedRetMarshaller.TypeInfo, this))
{
- try
- {
- return (p, MarshallingGenerators.Create(p, this, options));
- }
- catch (MarshallingNotSupportedException e)
- {
- this.diagnostics.ReportMarshallingNotSupported(this.stubMethod, p, e.NotSupportedDetails);
- return (p, MarshallingGenerators.Forwarder);
- }
+ // Update the native identifier for the return value
+ this.ReturnNativeIdentifier = $"{ReturnIdentifier}{GeneratedNativeIdentifierSuffix}";
}
static IEnumerable<int> GetInfoDependencies(TypePositionInfo info)
@@ -132,6 +151,19 @@ namespace Microsoft.Interop
}
return info.ManagedIndex;
}
+
+ BoundGenerator CreateGenerator(TypePositionInfo p)
+ {
+ try
+ {
+ return new BoundGenerator(p, MarshallingGenerators.Create(p, this, options));
+ }
+ catch (MarshallingNotSupportedException e)
+ {
+ marshallingNotSupportedCallback(p, e);
+ return new BoundGenerator(p, MarshallingGenerators.Forwarder);
+ }
+ }
}
public override (string managed, string native) GetIdentifiers(TypePositionInfo info)
@@ -164,17 +196,11 @@ namespace Microsoft.Interop
}
}
- public BlockSyntax GenerateSyntax(AttributeListSyntax? forwardedAttributes)
+ public BlockSyntax GenerateBody(string methodName, AttributeListSyntax? forwardedAttributes)
{
- string dllImportName = stubMethod.Name + "__PInvoke__";
+ string dllImportName = methodName + "__PInvoke__";
var setupStatements = new List<StatementSyntax>();
- if (retMarshaller.Generator.UsesNativeIdentifier(retMarshaller.TypeInfo, this))
- {
- // Update the native identifier for the return value
- ReturnNativeIdentifier = $"{ReturnIdentifier}{GeneratedNativeIdentifierSuffix}";
- }
-
foreach (var marshaller in paramMarshallers)
{
TypePositionInfo info = marshaller.TypeInfo;
@@ -197,8 +223,7 @@ namespace Microsoft.Interop
AppendVariableDeclations(setupStatements, info, marshaller.Generator);
}
- bool invokeReturnsVoid = retMarshaller.TypeInfo.ManagedType.SpecialType == SpecialType.System_Void;
- bool stubReturnsVoid = stubMethod.ReturnsVoid;
+ bool invokeReturnsVoid = retMarshaller.TypeInfo.ManagedType == SpecialTypeInfo.Void;
// Stub return is not the same as invoke return
if (!stubReturnsVoid && !retMarshaller.TypeInfo.IsManagedReturnPosition)
@@ -210,11 +235,6 @@ namespace Microsoft.Interop
Debug.Assert(paramMarshallers.Any() && paramMarshallers.Last().TypeInfo.IsManagedReturnPosition, "Expected stub return to be the last parameter for the invoke");
(TypePositionInfo stubRetTypeInfo, IMarshallingGenerator stubRetGenerator) = paramMarshallers.Last();
- if (stubRetGenerator.UsesNativeIdentifier(stubRetTypeInfo, this))
- {
- // Update the native identifier for the return value
- ReturnNativeIdentifier = $"{ReturnIdentifier}{GeneratedNativeIdentifierSuffix}";
- }
// Declare variables for stub return value
AppendVariableDeclations(setupStatements, stubRetTypeInfo, stubRetGenerator);
@@ -303,7 +323,7 @@ namespace Microsoft.Interop
.WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
.WithAttributeLists(
SingletonList(AttributeList(
- SingletonSeparatedList(CreateDllImportAttributeForTarget(GetTargetDllImportDataFromStubData())))));
+ SingletonSeparatedList(CreateDllImportAttributeForTarget(GetTargetDllImportDataFromStubData(methodName))))));
if (retMarshaller.Generator is IAttributedReturnTypeMarshallingGenerator retGenerator)
{
@@ -313,7 +333,7 @@ namespace Microsoft.Interop
dllImport = dllImport.AddAttributeLists(returnAttribute.WithTarget(AttributeTargetSpecifier(Identifier("return"))));
}
}
-
+
if (forwardedAttributes is not null)
{
dllImport = dllImport.AddAttributeLists(forwardedAttributes);
@@ -334,7 +354,6 @@ namespace Microsoft.Interop
if (!invokeReturnsVoid && (stage is Stage.Setup or Stage.Cleanup))
{
- // Handle setup and unmarshalling for return
var retStatements = retMarshaller.Generator.Generate(retMarshaller.TypeInfo, this);
statementsToUpdate.AddRange(retStatements);
}
@@ -400,7 +419,6 @@ namespace Microsoft.Interop
}
StatementSyntax invokeStatement;
-
// Assign to return value if necessary
if (invokeReturnsVoid)
{
@@ -442,7 +460,6 @@ namespace Microsoft.Interop
invokeStatement = Block(clearLastError, invokeStatement, getLastError);
}
-
// Nest invocation in fixed statements
if (fixedStatements.Any())
{
@@ -467,13 +484,13 @@ namespace Microsoft.Interop
private void AppendVariableDeclations(List<StatementSyntax> statementsToUpdate, TypePositionInfo info, IMarshallingGenerator generator)
{
- var (managed, native) = GetIdentifiers(info);
+ var (managed, native) = this.GetIdentifiers(info);
// Declare variable for return value
if (info.IsManagedReturnPosition || info.IsNativeReturnPosition)
{
statementsToUpdate.Add(MarshallerHelpers.DeclareWithDefault(
- info.ManagedType.AsTypeSyntax(),
+ info.ManagedType.Syntax,
managed));
}
@@ -486,8 +503,9 @@ namespace Microsoft.Interop
}
}
- private static AttributeSyntax CreateDllImportAttributeForTarget(DllImportStub.GeneratedDllImportData targetDllImportData)
+ private static AttributeSyntax CreateDllImportAttributeForTarget(GeneratedDllImportData targetDllImportData)
{
+ Debug.Assert(targetDllImportData.EntryPoint is not null);
var newAttributeArgs = new List<AttributeArgumentSyntax>
{
AttributeArgument(LiteralExpression(
@@ -496,46 +514,46 @@ namespace Microsoft.Interop
AttributeArgument(
NameEquals(nameof(DllImportAttribute.EntryPoint)),
null,
- CreateStringExpressionSyntax(targetDllImportData.EntryPoint))
+ CreateStringExpressionSyntax(targetDllImportData.EntryPoint!))
};
- if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.BestFitMapping))
+ if (targetDllImportData.IsUserDefined.HasFlag(DllImportMember.BestFitMapping))
{
var name = NameEquals(nameof(DllImportAttribute.BestFitMapping));
var value = CreateBoolExpressionSyntax(targetDllImportData.BestFitMapping);
newAttributeArgs.Add(AttributeArgument(name, null, value));
}
- if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.CallingConvention))
+ if (targetDllImportData.IsUserDefined.HasFlag(DllImportMember.CallingConvention))
{
var name = NameEquals(nameof(DllImportAttribute.CallingConvention));
var value = CreateEnumExpressionSyntax(targetDllImportData.CallingConvention);
newAttributeArgs.Add(AttributeArgument(name, null, value));
}
- if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.CharSet))
+ if (targetDllImportData.IsUserDefined.HasFlag(DllImportMember.CharSet))
{
var name = NameEquals(nameof(DllImportAttribute.CharSet));
var value = CreateEnumExpressionSyntax(targetDllImportData.CharSet);
newAttributeArgs.Add(AttributeArgument(name, null, value));
}
- if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.ExactSpelling))
+ if (targetDllImportData.IsUserDefined.HasFlag(DllImportMember.ExactSpelling))
{
var name = NameEquals(nameof(DllImportAttribute.ExactSpelling));
var value = CreateBoolExpressionSyntax(targetDllImportData.ExactSpelling);
newAttributeArgs.Add(AttributeArgument(name, null, value));
}
- if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.PreserveSig))
+ if (targetDllImportData.IsUserDefined.HasFlag(DllImportMember.PreserveSig))
{
var name = NameEquals(nameof(DllImportAttribute.PreserveSig));
var value = CreateBoolExpressionSyntax(targetDllImportData.PreserveSig);
newAttributeArgs.Add(AttributeArgument(name, null, value));
}
- if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.SetLastError))
+ if (targetDllImportData.IsUserDefined.HasFlag(DllImportMember.SetLastError))
{
var name = NameEquals(nameof(DllImportAttribute.SetLastError));
var value = CreateBoolExpressionSyntax(targetDllImportData.SetLastError);
newAttributeArgs.Add(AttributeArgument(name, null, value));
}
- if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.ThrowOnUnmappableChar))
+ if (targetDllImportData.IsUserDefined.HasFlag(DllImportMember.ThrowOnUnmappableChar))
{
var name = NameEquals(nameof(DllImportAttribute.ThrowOnUnmappableChar));
var value = CreateBoolExpressionSyntax(targetDllImportData.ThrowOnUnmappableChar);
@@ -571,31 +589,22 @@ namespace Microsoft.Interop
}
}
- DllImportStub.GeneratedDllImportData GetTargetDllImportDataFromStubData()
+ GeneratedDllImportData GetTargetDllImportDataFromStubData(string methodName)
{
- DllImportStub.DllImportMember membersToForward = DllImportStub.DllImportMember.All
+ DllImportMember membersToForward = DllImportMember.All
// https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.dllimportattribute.preservesig
// If PreserveSig=false (default is true), the P/Invoke stub checks/converts a returned HRESULT to an exception.
- & ~DllImportStub.DllImportMember.PreserveSig
+ & ~DllImportMember.PreserveSig
// https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.dllimportattribute.setlasterror
// If SetLastError=true (default is false), the P/Invoke stub gets/caches the last error after invoking the native function.
- & ~DllImportStub.DllImportMember.SetLastError;
+ & ~DllImportMember.SetLastError;
if (options.GenerateForwarders())
{
- membersToForward = DllImportStub.DllImportMember.All;
+ membersToForward = DllImportMember.All;
}
- var targetDllImportData = new DllImportStub.GeneratedDllImportData
+ var targetDllImportData = dllImportData with
{
- CharSet = dllImportData.CharSet,
- BestFitMapping = dllImportData.BestFitMapping,
- CallingConvention = dllImportData.CallingConvention,
- EntryPoint = dllImportData.EntryPoint,
- ModuleName = dllImportData.ModuleName,
- ExactSpelling = dllImportData.ExactSpelling,
- SetLastError = dllImportData.SetLastError,
- PreserveSig = dllImportData.PreserveSig,
- ThrowOnUnmappableChar = dllImportData.ThrowOnUnmappableChar,
IsUserDefined = dllImportData.IsUserDefined & membersToForward
};
@@ -604,9 +613,9 @@ namespace Microsoft.Interop
//
// N.B. The export discovery logic is identical regardless of where
// the name is defined (i.e. method name vs EntryPoint property).
- if (!targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.EntryPoint))
+ if (!targetDllImportData.IsUserDefined.HasFlag(DllImportMember.EntryPoint))
{
- targetDllImportData.EntryPoint = stubMethod.Name;
+ targetDllImportData = targetDllImportData with { EntryPoint = methodName };
}
return targetDllImportData;
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypePositionInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypePositionInfo.cs
index 0f011d9f084..2cbbcb87bec 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypePositionInfo.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypePositionInfo.cs
@@ -40,27 +40,15 @@ namespace Microsoft.Interop
/// <summary>
/// Positional type information involved in unmanaged/managed scenarios.
/// </summary>
- internal sealed record TypePositionInfo
+ internal sealed record TypePositionInfo(ManagedTypeInfo ManagedType, MarshallingInfo MarshallingAttributeInfo)
{
public const int UnsetIndex = int.MinValue;
public const int ReturnIndex = UnsetIndex + 1;
-// We don't need the warnings around not setting the various
-// non-nullable fields/properties on this type in the constructor
-// since we always use a property initializer.
-#pragma warning disable 8618
- private TypePositionInfo()
- {
- this.ManagedIndex = UnsetIndex;
- this.NativeIndex = UnsetIndex;
- }
-#pragma warning restore
-
- public string InstanceIdentifier { get; init; }
- public ITypeSymbol ManagedType { get; init; }
+ public string InstanceIdentifier { get; init; } = string.Empty;
- public RefKind RefKind { get; init; }
- public SyntaxKind RefKindSyntax { get; init; }
+ public RefKind RefKind { get; init; } = RefKind.None;
+ public SyntaxKind RefKindSyntax { get; init; } = SyntaxKind.None;
public bool IsByRef => RefKind != RefKind.None;
@@ -69,40 +57,22 @@ namespace Microsoft.Interop
public bool IsManagedReturnPosition { get => this.ManagedIndex == ReturnIndex; }
public bool IsNativeReturnPosition { get => this.NativeIndex == ReturnIndex; }
- public int ManagedIndex { get; init; }
- public int NativeIndex { get; init; }
-
- public MarshallingInfo MarshallingAttributeInfo { get; init; }
+ public int ManagedIndex { get; init; } = UnsetIndex;
+ public int NativeIndex { get; init; } = UnsetIndex;
public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, MarshallingInfo marshallingInfo, Compilation compilation)
{
- var typeInfo = new TypePositionInfo()
+ var typeInfo = new TypePositionInfo(ManagedTypeInfo.CreateTypeInfoForTypeSymbol(paramSymbol.Type), marshallingInfo)
{
- ManagedType = paramSymbol.Type,
InstanceIdentifier = ParseToken(paramSymbol.Name).IsReservedKeyword() ? $"@{paramSymbol.Name}" : paramSymbol.Name,
RefKind = paramSymbol.RefKind,
RefKindSyntax = RefKindToSyntax(paramSymbol.RefKind),
- MarshallingAttributeInfo = marshallingInfo,
ByValueContentsMarshalKind = GetByValueContentsMarshalKind(paramSymbol.GetAttributes(), compilation)
};
return typeInfo;
}
- public static TypePositionInfo CreateForType(ITypeSymbol type, MarshallingInfo marshallingInfo, string identifier = "")
- {
- var typeInfo = new TypePositionInfo()
- {
- ManagedType = type,
- InstanceIdentifier = identifier,
- RefKind = RefKind.None,
- RefKindSyntax = SyntaxKind.None,
- MarshallingAttributeInfo = marshallingInfo
- };
-
- return typeInfo;
- }
-
private static ByValueContentsMarshalKind GetByValueContentsMarshalKind(IEnumerable<AttributeData> attributes, Compilation compilation)
{
INamedTypeSymbol outAttributeType = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_OutAttribute)!;
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypeSymbolExtensions.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypeSymbolExtensions.cs
index a37aa5708ef..e9b3b8261a9 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypeSymbolExtensions.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypeSymbolExtensions.cs
@@ -163,9 +163,9 @@ namespace Microsoft.Interop
return SyntaxFactory.ParseTypeName(type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat));
}
- public static bool IsIntegralType(this ITypeSymbol type)
+ public static bool IsIntegralType(this SpecialType type)
{
- return type.SpecialType switch
+ return type switch
{
SpecialType.System_SByte
or SpecialType.System_Byte
diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/UnreachableException.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/UnreachableException.cs
new file mode 100644
index 00000000000..f33ae8b9564
--- /dev/null
+++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/UnreachableException.cs
@@ -0,0 +1,13 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Microsoft.Interop
+{
+ /// <summary>
+ /// An exception that should be thrown on code-paths that are unreachable.
+ /// </summary>
+ internal class UnreachableException : Exception
+ {
+ }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/Ancillary.Interop.csproj b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/Ancillary.Interop.csproj
index 7af1a40c40b..e19d9b4e5d9 100644
--- a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/Ancillary.Interop.csproj
+++ b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/Ancillary.Interop.csproj
@@ -3,7 +3,6 @@
<PropertyGroup>
<AssemblyName>Microsoft.Interop.Ancillary</AssemblyName>
<TargetFramework>net6.0</TargetFramework>
- <LangVersion>8.0</LangVersion>
<RootNamespace>System.Runtime.InteropServices</RootNamespace>
<Nullable>enable</Nullable>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
diff --git a/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/IncrementalGenerationTests.cs b/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/IncrementalGenerationTests.cs
new file mode 100644
index 00000000000..e37e6fdd7ab
--- /dev/null
+++ b/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/IncrementalGenerationTests.cs
@@ -0,0 +1,203 @@
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.Text;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Xunit;
+using static Microsoft.Interop.DllImportGenerator;
+
+namespace DllImportGenerator.UnitTests
+{
+ public class IncrementalGenerationTests
+ {
+ public const string RequiresIncrementalSyntaxTreeModifySupport = "The GeneratorDriver treats all SyntaxTree replace operations on a Compilation as an Add/Remove operation instead of a Modify operation"
+ + ", so all cached results based on that input are thrown out. As a result, we cannot validate that unrelated changes within the same SyntaxTree do not cause regeneration.";
+
+ [Fact]
+ public async Task AddingNewUnrelatedType_DoesNotRegenerateSource()
+ {
+ string source = CodeSnippets.BasicParametersAndModifiers<int>();
+
+ Compilation comp1 = await TestUtils.CreateCompilation(source);
+
+ Microsoft.Interop.DllImportGenerator generator = new();
+ GeneratorDriver driver = TestUtils.CreateDriver(comp1, null, new IIncrementalGenerator[] { generator });
+
+ driver = driver.RunGenerators(comp1);
+
+ generator.IncrementalTracker = new IncrementalityTracker();
+
+ Compilation comp2 = comp1.AddSyntaxTrees(CSharpSyntaxTree.ParseText("struct Foo {}", new CSharpParseOptions(LanguageVersion.Preview)));
+ driver.RunGenerators(comp2);
+
+ Assert.Collection(generator.IncrementalTracker.ExecutedSteps,
+ step =>
+ {
+ Assert.Equal(IncrementalityTracker.StepName.CalculateStubInformation, step.Step);
+ });
+ }
+
+ [Fact(Skip = RequiresIncrementalSyntaxTreeModifySupport)]
+ public async Task AppendingUnrelatedSource_DoesNotRegenerateSource()
+ {
+ string source = $"namespace NS{{{CodeSnippets.BasicParametersAndModifiers<int>()}}}";
+
+ SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Preview));
+
+ Compilation comp1 = await TestUtils.CreateCompilation(new[] { syntaxTree });
+
+ Microsoft.Interop.DllImportGenerator generator = new();
+ GeneratorDriver driver = TestUtils.CreateDriver(comp1, null, new[] { generator });
+
+ driver = driver.RunGenerators(comp1);
+
+ generator.IncrementalTracker = new IncrementalityTracker();
+
+ SyntaxTree newTree = syntaxTree.WithRootAndOptions(syntaxTree.GetCompilationUnitRoot().AddMembers(SyntaxFactory.ParseMemberDeclaration("struct Foo {}")!), syntaxTree.Options);
+
+ Compilation comp2 = comp1.ReplaceSyntaxTree(comp1.SyntaxTrees.First(), newTree);
+ driver.RunGenerators(comp2);
+
+ Assert.Collection(generator.IncrementalTracker.ExecutedSteps,
+ step =>
+ {
+ Assert.Equal(IncrementalityTracker.StepName.CalculateStubInformation, step.Step);
+ });
+ }
+
+ [Fact]
+ public async Task AddingFileWithNewGeneratedDllImport_DoesNotRegenerateOriginalMethod()
+ {
+ string source = CodeSnippets.BasicParametersAndModifiers<int>();
+
+ Compilation comp1 = await TestUtils.CreateCompilation(source);
+
+ Microsoft.Interop.DllImportGenerator generator = new();
+ GeneratorDriver driver = TestUtils.CreateDriver(comp1, null, new[] { generator });
+
+ driver = driver.RunGenerators(comp1);
+
+ generator.IncrementalTracker = new IncrementalityTracker();
+
+ Compilation comp2 = comp1.AddSyntaxTrees(CSharpSyntaxTree.ParseText(CodeSnippets.BasicParametersAndModifiers<bool>(), new CSharpParseOptions(LanguageVersion.Preview)));
+ driver.RunGenerators(comp2);
+
+ Assert.Equal(2, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.CalculateStubInformation));
+ Assert.Equal(1, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.GenerateSingleStub));
+ Assert.Equal(1, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.NormalizeWhitespace));
+ Assert.Equal(1, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.ConcatenateStubs));
+ Assert.Equal(1, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.OutputSourceFile));
+ }
+
+ [Fact]
+ public async Task ReplacingFileWithNewGeneratedDllImport_DoesNotRegenerateStubsInOtherFiles()
+ {
+ string source = CodeSnippets.BasicParametersAndModifiers<int>();
+
+ Compilation comp1 = await TestUtils.CreateCompilation(new string[] { CodeSnippets.BasicParametersAndModifiers<int>(), CodeSnippets.BasicParametersAndModifiers<bool>() });
+
+ Microsoft.Interop.DllImportGenerator generator = new();
+ GeneratorDriver driver = TestUtils.CreateDriver(comp1, null, new[] { generator });
+
+ driver = driver.RunGenerators(comp1);
+
+ generator.IncrementalTracker = new IncrementalityTracker();
+
+ Compilation comp2 = comp1.ReplaceSyntaxTree(comp1.SyntaxTrees.First(), CSharpSyntaxTree.ParseText(CodeSnippets.BasicParametersAndModifiers<ulong>(), new CSharpParseOptions(LanguageVersion.Preview)));
+ driver.RunGenerators(comp2);
+
+ Assert.Equal(2, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.CalculateStubInformation));
+ Assert.Equal(1, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.GenerateSingleStub));
+ Assert.Equal(1, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.NormalizeWhitespace));
+ Assert.Equal(1, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.ConcatenateStubs));
+ Assert.Equal(1, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.OutputSourceFile));
+ }
+
+ [Fact]
+ public async Task ChangingMarshallingStrategy_RegeneratesStub()
+ {
+ string stubSource = CodeSnippets.BasicParametersAndModifiers("CustomType");
+
+ string customTypeImpl1 = "struct CustomType { System.IntPtr handle; }";
+
+ string customTypeImpl2 = "class CustomType : Microsoft.Win32.SafeHandles.SafeHandleZeroOrMinusOneIsInvalid { public CustomType():base(true){} protected override bool ReleaseHandle(){return true;} }";
+
+
+ Compilation comp1 = await TestUtils.CreateCompilation(stubSource);
+
+ SyntaxTree customTypeImpl1Tree = CSharpSyntaxTree.ParseText(customTypeImpl1, new CSharpParseOptions(LanguageVersion.Preview));
+ comp1 = comp1.AddSyntaxTrees(customTypeImpl1Tree);
+
+ Microsoft.Interop.DllImportGenerator generator = new();
+ GeneratorDriver driver = TestUtils.CreateDriver(comp1, null, new[] { generator });
+
+ driver = driver.RunGenerators(comp1);
+
+ generator.IncrementalTracker = new IncrementalityTracker();
+
+ Compilation comp2 = comp1.ReplaceSyntaxTree(customTypeImpl1Tree, CSharpSyntaxTree.ParseText(customTypeImpl2, new CSharpParseOptions(LanguageVersion.Preview)));
+ driver.RunGenerators(comp2);
+
+ Assert.Collection(generator.IncrementalTracker.ExecutedSteps,
+ step =>
+ {
+ Assert.Equal(IncrementalityTracker.StepName.CalculateStubInformation, step.Step);
+ },
+ step =>
+ {
+ Assert.Equal(IncrementalityTracker.StepName.GenerateSingleStub, step.Step);
+ },
+ step =>
+ {
+ Assert.Equal(IncrementalityTracker.StepName.NormalizeWhitespace, step.Step);
+ },
+ step =>
+ {
+ Assert.Equal(IncrementalityTracker.StepName.ConcatenateStubs, step.Step);
+ },
+ step =>
+ {
+ Assert.Equal(IncrementalityTracker.StepName.OutputSourceFile, step.Step);
+ });
+ }
+
+ [Fact(Skip = RequiresIncrementalSyntaxTreeModifySupport)]
+ public async Task ChangingMarshallingAttributes_SameStrategy_DoesNotRegenerate()
+ {
+ string source = CodeSnippets.BasicParametersAndModifiers<bool>();
+
+ SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Preview));
+
+ Compilation comp1 = await TestUtils.CreateCompilation(new[] { syntaxTree });
+
+ Microsoft.Interop.DllImportGenerator generator = new();
+ GeneratorDriver driver = TestUtils.CreateDriver(comp1, null, new[] { generator });
+
+ driver = driver.RunGenerators(comp1);
+
+ generator.IncrementalTracker = new IncrementalityTracker();
+
+ SyntaxTree newTree = syntaxTree.WithRootAndOptions(
+ syntaxTree.GetCompilationUnitRoot().AddMembers(
+ SyntaxFactory.ParseMemberDeclaration(
+ CodeSnippets.MarshalAsParametersAndModifiers<bool>(System.Runtime.InteropServices.UnmanagedType.Bool))!),
+ syntaxTree.Options);
+
+ Compilation comp2 = comp1.ReplaceSyntaxTree(comp1.SyntaxTrees.First(), newTree);
+ driver.RunGenerators(comp2);
+
+ Assert.Collection(generator.IncrementalTracker.ExecutedSteps,
+ step =>
+ {
+ Assert.Equal(IncrementalityTracker.StepName.CalculateStubInformation, step.Step);
+ },
+ step =>
+ {
+ Assert.Equal(IncrementalityTracker.StepName.GenerateSingleStub, step.Step);
+ });
+ }
+ }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/TestUtils.cs b/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/TestUtils.cs
index fd49c82ccf9..db7b0f3fecc 100644
--- a/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/TestUtils.cs
+++ b/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/TestUtils.cs
@@ -45,14 +45,26 @@ namespace DllImportGenerator.UnitTests
/// <param name="outputKind">Output type</param>
/// <param name="allowUnsafe">Whether or not use of the unsafe keyword should be allowed</param>
/// <returns>The resulting compilation</returns>
- public static async Task<Compilation> CreateCompilation(string source, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true, IEnumerable<string>? preprocessorSymbols = null)
+ public static Task<Compilation> CreateCompilation(string source, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true, IEnumerable<string>? preprocessorSymbols = null)
{
- var (mdRefs, ancillary) = GetReferenceAssemblies();
+ return CreateCompilation(new[] { source }, outputKind, allowUnsafe, preprocessorSymbols);
+ }
- return CSharpCompilation.Create("compilation",
- new[] { CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Preview, preprocessorSymbols: preprocessorSymbols)) },
- (await mdRefs.ResolveAsync(LanguageNames.CSharp, CancellationToken.None)).Add(ancillary),
- new CSharpCompilationOptions(outputKind, allowUnsafe: allowUnsafe));
+ /// <summary>
+ /// Create a compilation given sources
+ /// </summary>
+ /// <param name="sources">Sources to compile</param>
+ /// <param name="outputKind">Output type</param>
+ /// <param name="allowUnsafe">Whether or not use of the unsafe keyword should be allowed</param>
+ /// <returns>The resulting compilation</returns>
+ public static Task<Compilation> CreateCompilation(string[] sources, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true, IEnumerable<string>? preprocessorSymbols = null)
+ {
+ return CreateCompilation(
+ sources.Select(source =>
+ CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Preview, preprocessorSymbols: preprocessorSymbols))).ToArray(),
+ outputKind,
+ allowUnsafe,
+ preprocessorSymbols);
}
/// <summary>
@@ -62,13 +74,12 @@ namespace DllImportGenerator.UnitTests
/// <param name="outputKind">Output type</param>
/// <param name="allowUnsafe">Whether or not use of the unsafe keyword should be allowed</param>
/// <returns>The resulting compilation</returns>
- public static async Task<Compilation> CreateCompilation(string[] sources, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true, IEnumerable<string>? preprocessorSymbols = null)
+ public static async Task<Compilation> CreateCompilation(SyntaxTree[] sources, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true, IEnumerable<string>? preprocessorSymbols = null)
{
var (mdRefs, ancillary) = GetReferenceAssemblies();
return CSharpCompilation.Create("compilation",
- sources.Select(source =>
- CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Preview, preprocessorSymbols: preprocessorSymbols))).ToArray(),
+ sources,
(await mdRefs.ResolveAsync(LanguageNames.CSharp, CancellationToken.None)).Add(ancillary),
new CSharpCompilationOptions(outputKind, allowUnsafe: allowUnsafe));
}
@@ -81,10 +92,23 @@ namespace DllImportGenerator.UnitTests
/// <param name="outputKind">Output type</param>
/// <param name="allowUnsafe">Whether or not use of the unsafe keyword should be allowed</param>
/// <returns>The resulting compilation</returns>
- public static async Task<Compilation> CreateCompilationWithReferenceAssemblies(string source, ReferenceAssemblies referenceAssemblies, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true)
+ public static Task<Compilation> CreateCompilationWithReferenceAssemblies(string source, ReferenceAssemblies referenceAssemblies, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true)
+ {
+ return CreateCompilationWithReferenceAssemblies(new[] { CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Preview)) }, referenceAssemblies, outputKind, allowUnsafe);
+ }
+
+ /// <summary>
+ /// Create a compilation given source and reference assemblies
+ /// </summary>
+ /// <param name="source">Source to compile</param>
+ /// <param name="referenceAssemblies">Reference assemblies to include</param>
+ /// <param name="outputKind">Output type</param>
+ /// <param name="allowUnsafe">Whether or not use of the unsafe keyword should be allowed</param>
+ /// <returns>The resulting compilation</returns>
+ public static async Task<Compilation> CreateCompilationWithReferenceAssemblies(SyntaxTree[] sources, ReferenceAssemblies referenceAssemblies, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true)
{
return CSharpCompilation.Create("compilation",
- new[] { CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Preview)) },
+ sources,
(await referenceAssemblies.ResolveAsync(LanguageNames.CSharp, CancellationToken.None)),
new CSharpCompilationOptions(outputKind, allowUnsafe: allowUnsafe));
}
@@ -96,7 +120,7 @@ namespace DllImportGenerator.UnitTests
"net6.0",
new PackageIdentity(
"Microsoft.NETCore.App.Ref",
- "6.0.0-preview.6.21317.4"),
+ "6.0.0-preview.7.21377.19"),
Path.Combine("ref", "net6.0"))
.WithNuGetConfigFilePath(Path.Combine(Path.GetDirectoryName(typeof(TestUtils).Assembly.Location)!, "NuGet.config"));
@@ -114,7 +138,7 @@ namespace DllImportGenerator.UnitTests
/// <param name="diagnostics">Resulting diagnostics</param>
/// <param name="generators">Source generator instances</param>
/// <returns>The resulting compilation</returns>
- public static Compilation RunGenerators(Compilation comp, out ImmutableArray<Diagnostic> diagnostics, params ISourceGenerator[] generators)
+ public static Compilation RunGenerators(Compilation comp, out ImmutableArray<Diagnostic> diagnostics, params IIncrementalGenerator[] generators)
{
CreateDriver(comp, null, generators).RunGeneratorsAndUpdateCompilation(comp, out var d, out diagnostics);
return d;
@@ -127,15 +151,15 @@ namespace DllImportGenerator.UnitTests
/// <param name="diagnostics">Resulting diagnostics</param>
/// <param name="generators">Source generator instances</param>
/// <returns>The resulting compilation</returns>
- public static Compilation RunGenerators(Compilation comp, AnalyzerConfigOptionsProvider options, out ImmutableArray<Diagnostic> diagnostics, params ISourceGenerator[] generators)
+ public static Compilation RunGenerators(Compilation comp, AnalyzerConfigOptionsProvider options, out ImmutableArray<Diagnostic> diagnostics, params IIncrementalGenerator[] generators)
{
CreateDriver(comp, options, generators).RunGeneratorsAndUpdateCompilation(comp, out var d, out diagnostics);
return d;
}
- private static GeneratorDriver CreateDriver(Compilation c, AnalyzerConfigOptionsProvider? options, ISourceGenerator[] generators)
+ public static GeneratorDriver CreateDriver(Compilation c, AnalyzerConfigOptionsProvider? options, IIncrementalGenerator[] generators)
=> CSharpGeneratorDriver.Create(
- ImmutableArray.Create(generators),
+ ImmutableArray.Create(generators.Select(gen => gen.AsSourceGenerator()).ToArray()),
parseOptions: (CSharpParseOptions)c.SyntaxTrees.First().Options,
optionsProvider: options);
}