Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/dotnet/runtime.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLevi Broderick <GrabYourPitchforks@users.noreply.github.com>2020-07-15 09:12:10 +0300
committerGitHub <noreply@github.com>2020-07-15 09:12:10 +0300
commit53976d38b1bd6917b8fa4d1dd4f009728ece3adb (patch)
treed1ffc5e72f6e48969b5eccfaae9620a0fd10b030
parentfe9e53e14694e281818945c28ea7468053248fb7 (diff)
[release/5.0-preview7] Disallow unrestricted polymorphic deserialization in DataSet (#39314)v5.0.0-preview.7.20364.11
Fixes CVE-2020-1147 https://portal.msrc.microsoft.com/en-us/security-guidance/advisory/CVE-2020-1147 See also https://go.microsoft.com/fwlink/?linkid=2132227.
-rw-r--r--src/libraries/System.Data.Common/src/Resources/Strings.resx1
-rw-r--r--src/libraries/System.Data.Common/src/System.Data.Common.csproj25
-rw-r--r--src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs3
-rw-r--r--src/libraries/System.Data.Common/src/System/Data/DataColumn.cs1
-rw-r--r--src/libraries/System.Data.Common/src/System/Data/DataException.cs1
-rw-r--r--src/libraries/System.Data.Common/src/System/Data/DataSet.cs6
-rw-r--r--src/libraries/System.Data.Common/src/System/Data/DataTable.cs6
-rw-r--r--src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs24
-rw-r--r--src/libraries/System.Data.Common/src/System/Data/LocalAppContextSwitches.cs18
-rw-r--r--src/libraries/System.Data.Common/src/System/Data/TypeLimiter.cs305
-rw-r--r--src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj1
-rw-r--r--src/libraries/System.Data.Common/tests/System/Data/RestrictedTypeHandlingTests.cs446
12 files changed, 824 insertions, 13 deletions
diff --git a/src/libraries/System.Data.Common/src/Resources/Strings.resx b/src/libraries/System.Data.Common/src/Resources/Strings.resx
index 4535a6944bf..671bf490bc1 100644
--- a/src/libraries/System.Data.Common/src/Resources/Strings.resx
+++ b/src/libraries/System.Data.Common/src/Resources/Strings.resx
@@ -165,6 +165,7 @@
<data name="Data_ArgumentOutOfRange" xml:space="preserve"><value>'{0}' argument is out of range.</value></data>
<data name="Data_ArgumentNull" xml:space="preserve"><value>'{0}' argument cannot be null.</value></data>
<data name="Data_ArgumentContainsNull" xml:space="preserve"><value>'{0}' argument contains null value.</value></data>
+ <data name="Data_TypeNotAllowed" xml:space="preserve"><value>Type '{0}' is not allowed here. See https://go.microsoft.com/fwlink/?linkid=2132227 for more details.</value></data>
<data name="DataColumns_OutOfRange" xml:space="preserve"><value>Cannot find column {0}.</value></data>
<data name="DataColumns_Add1" xml:space="preserve"><value>Column '{0}' already belongs to this DataTable.</value></data>
<data name="DataColumns_Add2" xml:space="preserve"><value>Column '{0}' already belongs to another DataTable.</value></data>
diff --git a/src/libraries/System.Data.Common/src/System.Data.Common.csproj b/src/libraries/System.Data.Common/src/System.Data.Common.csproj
index 20091047d6b..7beb59ac3a8 100644
--- a/src/libraries/System.Data.Common/src/System.Data.Common.csproj
+++ b/src/libraries/System.Data.Common/src/System.Data.Common.csproj
@@ -2,7 +2,7 @@
<PropertyGroup>
<AssemblyName>System.Data.Common</AssemblyName>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
- <TargetFrameworks>$(NetCoreAppCurrent)</TargetFrameworks>
+ <TargetFrameworks>$(NetCoreAppCurrent)-Windows_NT;$(NetCoreAppCurrent)-Unix</TargetFrameworks>
</PropertyGroup>
<ItemGroup>
<Compile Include="System.Data.Common.TypeForwards.cs" />
@@ -123,6 +123,10 @@
<Compile Include="System\Data\KeyRestrictionBehavior.cs" />
<Compile Include="System\Data\LinqDataView.cs" />
<Compile Include="System\Data\LoadOption.cs" />
+ <Compile Include="System\Data\LocalAppContextSwitches.cs" />
+ <Compile Include="$(CommonPath)System\LocalAppContextSwitches.Common.cs">
+ <Link>Common\System\LocalAppContextSwitches.Common.cs</Link>
+ </Compile>
<Compile Include="System\Data\MappingType.cs" />
<Compile Include="System\Data\MergeFailedEvent.cs" />
<Compile Include="System\Data\MergeFailedEventHandler.cs" />
@@ -156,6 +160,7 @@
<Compile Include="System\Data\StrongTypingException.cs" />
<Compile Include="System\Data\TypedTableBase.cs" />
<Compile Include="System\Data\TypedTableBaseExtensions.cs" />
+ <Compile Include="System\Data\TypeLimiter.cs" />
<Compile Include="System\Data\UniqueConstraint.cs" />
<Compile Include="System\Data\UpdateRowSource.cs" />
<Compile Include="System\Data\Common\UInt64Storage.cs" />
@@ -295,25 +300,23 @@
<Compile Include="System\Data\ProviderBase\SchemaMapping.cs" />
</ItemGroup>
<ItemGroup>
- <Reference Include="System.Collections" />
+ <ProjectReference Include="$(CoreLibProject)" />
+ <ProjectReference Include="..\..\System.Collections\src\System.Collections.csproj" />
+ <ProjectReference Include="..\..\System.Collections.NonGeneric\src\System.Collections.NonGeneric.csproj" />
+ <ProjectReference Include="..\..\System.ComponentModel.TypeConverter\src\System.ComponentModel.TypeConverter.csproj" />
+ <ProjectReference Include="..\..\System.Runtime\src\System.Runtime.csproj" />
+ <ProjectReference Include="..\..\System.Runtime.Extensions\src\System.Runtime.Extensions.csproj" />
+ <ProjectReference Include="..\..\System.Private.Uri\src\System.Private.Uri.csproj" />
<Reference Include="System.Collections.Concurrent" />
- <Reference Include="System.Collections.NonGeneric" />
<Reference Include="System.ComponentModel" />
<Reference Include="System.ComponentModel.Primitives" />
- <Reference Include="System.ComponentModel.TypeConverter" />
- <Reference Include="System.Diagnostics.Tracing" />
+ <Reference Include="System.Drawing.Primitives" />
<Reference Include="System.Linq" />
<Reference Include="System.Linq.Expressions" />
- <Reference Include="System.Memory" />
<Reference Include="System.ObjectModel" />
- <Reference Include="System.Runtime" />
- <Reference Include="System.Runtime.Extensions" />
<Reference Include="System.Runtime.Numerics" />
<Reference Include="System.Runtime.Serialization.Formatters" />
- <Reference Include="System.Text.Encoding.Extensions" />
<Reference Include="System.Text.RegularExpressions" />
- <Reference Include="System.Threading" />
- <Reference Include="System.Threading.Thread" />
<Reference Include="System.Transactions.Local" />
<Reference Include="System.Xml.ReaderWriter" />
<Reference Include="System.Xml.XmlSerializer" />
diff --git a/src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs b/src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs
index ffe4960cce7..dc3fe053356 100644
--- a/src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs
+++ b/src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs
@@ -406,6 +406,9 @@ namespace System.Data.Common
if (type == typeof(object))
throw ExceptionBuilder.CanNotDeserializeObjectType();
+
+ TypeLimiter.EnsureTypeIsAllowed(type);
+
if (!isBaseCLRType)
{
retValue = System.Activator.CreateInstance(type, true);
diff --git a/src/libraries/System.Data.Common/src/System/Data/DataColumn.cs b/src/libraries/System.Data.Common/src/System/Data/DataColumn.cs
index d44ec1e0b88..560f6ecd7e7 100644
--- a/src/libraries/System.Data.Common/src/System/Data/DataColumn.cs
+++ b/src/libraries/System.Data.Common/src/System/Data/DataColumn.cs
@@ -143,6 +143,7 @@ namespace System.Data
private void UpdateColumnType(Type type, StorageType typeCode)
{
+ TypeLimiter.EnsureTypeIsAllowed(type);
_dataType = type;
_storageType = typeCode;
if (StorageType.DateTime != typeCode)
diff --git a/src/libraries/System.Data.Common/src/System/Data/DataException.cs b/src/libraries/System.Data.Common/src/System/Data/DataException.cs
index 7aa84800747..aede6f282f1 100644
--- a/src/libraries/System.Data.Common/src/System/Data/DataException.cs
+++ b/src/libraries/System.Data.Common/src/System/Data/DataException.cs
@@ -350,6 +350,7 @@ namespace System.Data
public static Exception ArgumentOutOfRange(string paramName) => _ArgumentOutOfRange(paramName, SR.Format(SR.Data_ArgumentOutOfRange, paramName));
public static Exception BadObjectPropertyAccess(string error) => _InvalidOperation(SR.Format(SR.DataConstraint_BadObjectPropertyAccess, error));
public static Exception ArgumentContainsNull(string paramName) => _Argument(paramName, SR.Format(SR.Data_ArgumentContainsNull, paramName));
+ public static Exception TypeNotAllowed(Type type) => _InvalidOperation(SR.Format(SR.Data_TypeNotAllowed, type.AssemblyQualifiedName));
//
diff --git a/src/libraries/System.Data.Common/src/System/Data/DataSet.cs b/src/libraries/System.Data.Common/src/System/Data/DataSet.cs
index f4d6ad887d5..c09b35203a9 100644
--- a/src/libraries/System.Data.Common/src/System/Data/DataSet.cs
+++ b/src/libraries/System.Data.Common/src/System/Data/DataSet.cs
@@ -1961,9 +1961,11 @@ namespace System.Data
internal XmlReadMode ReadXml(XmlReader reader, bool denyResolving)
{
+ IDisposable restrictedScope = null;
long logScopeId = DataCommonEventSource.Log.EnterScope("<ds.DataSet.ReadXml|INFO> {0}, denyResolving={1}", ObjectID, denyResolving);
try
{
+ restrictedScope = TypeLimiter.EnterRestrictedScope(this);
DataTable.DSRowDiffIdUsageSection rowDiffIdUsage = default;
try
{
@@ -2231,6 +2233,7 @@ namespace System.Data
}
finally
{
+ restrictedScope?.Dispose();
DataCommonEventSource.Log.ExitScope(logScopeId);
}
}
@@ -2467,9 +2470,11 @@ namespace System.Data
internal XmlReadMode ReadXml(XmlReader reader, XmlReadMode mode, bool denyResolving)
{
+ IDisposable restictedScope = null;
long logScopeId = DataCommonEventSource.Log.EnterScope("<ds.DataSet.ReadXml|INFO> {0}, mode={1}, denyResolving={2}", ObjectID, mode, denyResolving);
try
{
+ restictedScope = TypeLimiter.EnterRestrictedScope(this);
XmlReadMode ret = mode;
if (reader == null)
@@ -2711,6 +2716,7 @@ namespace System.Data
}
finally
{
+ restictedScope?.Dispose();
DataCommonEventSource.Log.ExitScope(logScopeId);
}
}
diff --git a/src/libraries/System.Data.Common/src/System/Data/DataTable.cs b/src/libraries/System.Data.Common/src/System/Data/DataTable.cs
index 917908ec2a4..5b2fa268c94 100644
--- a/src/libraries/System.Data.Common/src/System/Data/DataTable.cs
+++ b/src/libraries/System.Data.Common/src/System/Data/DataTable.cs
@@ -5659,9 +5659,11 @@ namespace System.Data
internal XmlReadMode ReadXml(XmlReader reader, bool denyResolving)
{
+ IDisposable restrictedScope = null;
long logScopeId = DataCommonEventSource.Log.EnterScope("<ds.DataTable.ReadXml|INFO> {0}, denyResolving={1}", ObjectID, denyResolving);
try
{
+ restrictedScope = TypeLimiter.EnterRestrictedScope(this);
RowDiffIdUsageSection rowDiffIdUsage = default;
try
{
@@ -5896,15 +5898,18 @@ namespace System.Data
}
finally
{
+ restrictedScope?.Dispose();
DataCommonEventSource.Log.ExitScope(logScopeId);
}
}
internal XmlReadMode ReadXml(XmlReader reader, XmlReadMode mode, bool denyResolving)
{
+ IDisposable restrictedScope = null;
RowDiffIdUsageSection rowDiffIdUsage = default;
try
{
+ restrictedScope = TypeLimiter.EnterRestrictedScope(this);
bool fSchemaFound = false;
bool fDataFound = false;
bool fIsXdr = false;
@@ -6190,6 +6195,7 @@ namespace System.Data
}
finally
{
+ restrictedScope?.Dispose();
// prepare and cleanup rowDiffId hashtable
rowDiffIdUsage.Cleanup();
}
diff --git a/src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs b/src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs
index 33aa941e303..472cd9be2b7 100644
--- a/src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs
+++ b/src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs
@@ -6,6 +6,7 @@ using System.Collections.Generic;
using System.Data.Common;
using System.Data.SqlTypes;
using System.Diagnostics;
+using System.Runtime.Serialization;
namespace System.Data
{
@@ -16,6 +17,7 @@ namespace System.Data
internal int _argumentCount = 0;
internal const int initialCapacity = 1;
internal ExpressionNode[] _arguments;
+ private readonly TypeLimiter _capturedLimiter = null;
private static readonly Function[] s_funcs = new Function[] {
new Function("Abs", FunctionId.Abs, typeof(object), true, false, 1, typeof(object), null, null),
@@ -40,6 +42,12 @@ namespace System.Data
internal FunctionNode(DataTable table, string name) : base(table)
{
+ // Because FunctionNode instances are created eagerly but evaluated lazily,
+ // we need to capture the deserialization scope here. The scope could be
+ // null if no deserialization is in progress.
+
+ _capturedLimiter = TypeLimiter.Capture();
+
_name = name;
for (int i = 0; i < s_funcs.Length; i++)
{
@@ -289,6 +297,11 @@ namespace System.Data
throw ExprException.InvalidType(typeName);
}
+ // ReadXml might not be on the current call stack. So we'll use the TypeLimiter
+ // that was captured when this FunctionNode instance was created.
+
+ TypeLimiter.EnsureTypeIsAllowed(dataType, _capturedLimiter);
+
return dataType;
}
@@ -494,10 +507,17 @@ namespace System.Data
{
return SqlConvert.ChangeType2((decimal)SqlConvert.ChangeType2(argumentValues[0], StorageType.Decimal, typeof(decimal), FormatProvider), mytype, type, FormatProvider);
}
- return SqlConvert.ChangeType2(argumentValues[0], mytype, type, FormatProvider);
}
- return SqlConvert.ChangeType2(argumentValues[0], mytype, type, FormatProvider);
+ // The Convert function can be called lazily, outside of a previous Serialization Guard scope.
+ // If there was a type limiter scope on the stack at the time this Convert function was created,
+ // we must manually re-enter the Serialization Guard scope.
+
+ DeserializationToken deserializationToken = (_capturedLimiter != null) ? SerializationInfo.StartDeserialization() : default;
+ using (deserializationToken)
+ {
+ return SqlConvert.ChangeType2(argumentValues[0], mytype, type, FormatProvider);
+ }
}
return argumentValues[0];
diff --git a/src/libraries/System.Data.Common/src/System/Data/LocalAppContextSwitches.cs b/src/libraries/System.Data.Common/src/System/Data/LocalAppContextSwitches.cs
new file mode 100644
index 00000000000..42afa2b8fed
--- /dev/null
+++ b/src/libraries/System.Data.Common/src/System/Data/LocalAppContextSwitches.cs
@@ -0,0 +1,18 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Runtime.CompilerServices;
+
+namespace System
+{
+ internal static partial class LocalAppContextSwitches
+ {
+ private static int s_allowArbitraryTypeInstantiation;
+ public static bool AllowArbitraryTypeInstantiation
+ {
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ get => GetCachedSwitchValue("Switch.System.Data.AllowArbitraryDataSetTypeInstantiation", ref s_allowArbitraryTypeInstantiation);
+ }
+ }
+}
diff --git a/src/libraries/System.Data.Common/src/System/Data/TypeLimiter.cs b/src/libraries/System.Data.Common/src/System/Data/TypeLimiter.cs
new file mode 100644
index 00000000000..1ff77cb9952
--- /dev/null
+++ b/src/libraries/System.Data.Common/src/System/Data/TypeLimiter.cs
@@ -0,0 +1,305 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.Data.SqlTypes;
+using System.Diagnostics;
+using System.Drawing;
+using System.Linq;
+using System.Numerics;
+using System.Runtime.Serialization;
+
+namespace System.Data
+{
+ internal sealed class TypeLimiter
+ {
+ [ThreadStatic]
+ private static Scope s_activeScope;
+
+ private Scope m_instanceScope;
+
+ private const string AppDomainDataSetDefaultAllowedTypesKey = "System.Data.DataSetDefaultAllowedTypes";
+
+ private TypeLimiter(Scope scope)
+ {
+ Debug.Assert(scope != null);
+ m_instanceScope = scope;
+ }
+
+ private static bool IsTypeLimitingDisabled
+ => LocalAppContextSwitches.AllowArbitraryTypeInstantiation;
+
+ /// <summary>
+ /// Captures the current <see cref="TypeLimiter"/> instance so that future
+ /// type checks can be performed against the allow list that was active during
+ /// the current deserialization scope.
+ /// </summary>
+ /// <remarks>
+ /// Returns null if no limiter is active.
+ /// </remarks>
+ public static TypeLimiter Capture()
+ {
+ Scope activeScope = s_activeScope;
+ return (activeScope != null) ? new TypeLimiter(activeScope) : null;
+ }
+
+ /// <summary>
+ /// Ensures the requested type is allowed by the rules of the active
+ /// deserialization scope. If a captured scope is provided, we'll use
+ /// that previously captured scope rather than the thread-static active
+ /// scope.
+ /// </summary>
+ /// <exception cref="InvalidOperationException">
+ /// If <paramref name="type"/> is not allowed.
+ /// </exception>
+ public static void EnsureTypeIsAllowed(Type type, TypeLimiter capturedLimiter = null)
+ {
+ if (type is null)
+ {
+ return; // nothing to check
+ }
+
+ Scope capturedScope = capturedLimiter?.m_instanceScope ?? s_activeScope;
+ if (capturedScope is null)
+ {
+ return; // we're not in a restricted scope
+ }
+
+ if (capturedScope.IsAllowedType(type))
+ {
+ return; // type was explicitly allowed
+ }
+
+ // We encountered a type that wasn't in the allow list.
+ // Throw an exception to fail the current operation.
+
+ throw ExceptionBuilder.TypeNotAllowed(type);
+ }
+
+ public static IDisposable EnterRestrictedScope(DataSet dataSet)
+ {
+ if (IsTypeLimitingDisabled)
+ {
+ return null; // protections aren't enabled
+ }
+
+ Scope newScope = new Scope(s_activeScope, GetPreviouslyDeclaredDataTypes(dataSet));
+ s_activeScope = newScope;
+ return newScope;
+ }
+
+ public static IDisposable EnterRestrictedScope(DataTable dataTable)
+ {
+ if (IsTypeLimitingDisabled)
+ {
+ return null; // protections aren't enabled
+ }
+
+ Scope newScope = new Scope(s_activeScope, GetPreviouslyDeclaredDataTypes(dataTable));
+ s_activeScope = newScope;
+ return newScope;
+ }
+
+ /// <summary>
+ /// Given a <see cref="DataTable"/>, returns all of the <see cref="DataColumn.DataType"/>
+ /// values declared on the instance.
+ /// </summary>
+ private static IEnumerable<Type> GetPreviouslyDeclaredDataTypes(DataTable dataTable)
+ {
+ return (dataTable != null)
+ ? dataTable.Columns.Cast<DataColumn>().Select(column => column.DataType)
+ : Enumerable.Empty<Type>();
+ }
+
+ /// <summary>
+ /// Given a <see cref="DataSet"/>, returns all of the <see cref="DataColumn.DataType"/>
+ /// values declared on the instance.
+ /// </summary>
+ private static IEnumerable<Type> GetPreviouslyDeclaredDataTypes(DataSet dataSet)
+ {
+ return (dataSet != null)
+ ? dataSet.Tables.Cast<DataTable>().SelectMany(table => GetPreviouslyDeclaredDataTypes(table))
+ : Enumerable.Empty<Type>();
+ }
+
+ private sealed class Scope : IDisposable
+ {
+ /// <summary>
+ /// Types which are always allowed, unconditionally.
+ /// </summary>
+ private static readonly HashSet<Type> s_allowedTypes = new HashSet<Type>()
+ {
+ /* primitives */
+ typeof(bool),
+ typeof(char),
+ typeof(sbyte),
+ typeof(byte),
+ typeof(short),
+ typeof(ushort),
+ typeof(int),
+ typeof(uint),
+ typeof(long),
+ typeof(ulong),
+ typeof(float),
+ typeof(double),
+ typeof(decimal),
+ typeof(DateTime),
+ typeof(DateTimeOffset),
+ typeof(TimeSpan),
+ typeof(string),
+ typeof(Guid),
+ typeof(SqlBinary),
+ typeof(SqlBoolean),
+ typeof(SqlByte),
+ typeof(SqlBytes),
+ typeof(SqlChars),
+ typeof(SqlDateTime),
+ typeof(SqlDecimal),
+ typeof(SqlDouble),
+ typeof(SqlGuid),
+ typeof(SqlInt16),
+ typeof(SqlInt32),
+ typeof(SqlInt64),
+ typeof(SqlMoney),
+ typeof(SqlSingle),
+ typeof(SqlString),
+
+ /* non-primitives, but common */
+ typeof(object),
+ typeof(Type),
+ typeof(BigInteger),
+ typeof(Uri),
+
+ /* frequently used System.Drawing types */
+ typeof(Color),
+ typeof(Point),
+ typeof(PointF),
+ typeof(Rectangle),
+ typeof(RectangleF),
+ typeof(Size),
+ typeof(SizeF),
+ };
+
+ /// <summary>
+ /// Types which are allowed within the context of this scope.
+ /// </summary>
+ private HashSet<Type> m_allowedTypes;
+
+ /// <summary>
+ /// This thread's previous scope.
+ /// </summary>
+ private readonly Scope m_previousScope;
+
+ /// <summary>
+ /// The Serialization Guard token associated with this scope.
+ /// </summary>
+ private readonly DeserializationToken m_deserializationToken;
+
+ internal Scope(Scope previousScope, IEnumerable<Type> allowedTypes)
+ {
+ Debug.Assert(allowedTypes != null);
+
+ m_previousScope = previousScope;
+ m_allowedTypes = new HashSet<Type>(allowedTypes.Where(type => type != null));
+ m_deserializationToken = SerializationInfo.StartDeserialization();
+ }
+
+ public void Dispose()
+ {
+ if (this != s_activeScope)
+ {
+ // Stacks should never be popped out of order.
+ // We want to trap this condition in production.
+ Debug.Fail("Scope was popped out of order.");
+ throw new ObjectDisposedException(GetType().FullName);
+ }
+
+ m_deserializationToken.Dispose(); // it's a readonly struct, but Dispose still works properly
+ s_activeScope = m_previousScope; // could be null
+ }
+
+ public bool IsAllowedType(Type type)
+ {
+ Debug.Assert(type != null);
+
+ // Is the incoming type unconditionally allowed?
+
+ if (IsTypeUnconditionallyAllowed(type))
+ {
+ return true;
+ }
+
+ // The incoming type is allowed if the current scope or any nested inner
+ // scope allowed it.
+
+ for (Scope currentScope = this; currentScope != null; currentScope = currentScope.m_previousScope)
+ {
+ if (currentScope.m_allowedTypes.Contains(type))
+ {
+ return true;
+ }
+ }
+
+ // Did the application programmatically allow this type to be deserialized?
+
+ Type[] appDomainAllowedTypes = (Type[])AppDomain.CurrentDomain.GetData(AppDomainDataSetDefaultAllowedTypesKey);
+ if (appDomainAllowedTypes != null)
+ {
+ for (int i = 0; i < appDomainAllowedTypes.Length; i++)
+ {
+ if (type == appDomainAllowedTypes[i])
+ {
+ return true;
+ }
+ }
+ }
+
+ // All checks failed
+
+ return false;
+ }
+
+ private static bool IsTypeUnconditionallyAllowed(Type type)
+ {
+ TryAgain:
+ Debug.Assert(type != null);
+
+ // Check the list of unconditionally allowed types.
+
+ if (s_allowedTypes.Contains(type))
+ {
+ return true;
+ }
+
+ // Enums are also always allowed, as we optimistically assume the app
+ // developer didn't define a dangerous enum type.
+
+ if (type.IsEnum)
+ {
+ return true;
+ }
+
+ // Allow single-dimensional arrays of any unconditionally allowed type.
+
+ if (type.IsSZArray)
+ {
+ type = type.GetElementType();
+ goto TryAgain;
+ }
+
+ // Allow generic lists of any unconditionally allowed type.
+
+ if (type.IsGenericType && !type.IsGenericTypeDefinition && type.GetGenericTypeDefinition() == typeof(List<>))
+ {
+ type = type.GetGenericArguments()[0];
+ goto TryAgain;
+ }
+
+ // All checks failed.
+
+ return false;
+ }
+ }
+ }
+}
diff --git a/src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj b/src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj
index 4d0933db75b..9742209ccf9 100644
--- a/src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj
+++ b/src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj
@@ -112,6 +112,7 @@
<Compile Include="System\Data\VersionNotFoundException.cs" />
<Compile Include="System\Data\XmlDataLoaderTest.cs" />
<Compile Include="System\Data\XmlDataReaderTest.cs" />
+ <Compile Include="System\Data\RestrictedTypeHandlingTests.cs" />
<Compile Include="System\Xml\XmlDataDocumentTests.cs" />
<Compile Include="$(CommonTestPath)System\Diagnostics\Tracing\TestEventListener.cs"
Link="Common\System\Diagnostics\Tracing\TestEventListener.cs" />
diff --git a/src/libraries/System.Data.Common/tests/System/Data/RestrictedTypeHandlingTests.cs b/src/libraries/System.Data.Common/tests/System/Data/RestrictedTypeHandlingTests.cs
new file mode 100644
index 00000000000..3f76cebc805
--- /dev/null
+++ b/src/libraries/System.Data.Common/tests/System/Data/RestrictedTypeHandlingTests.cs
@@ -0,0 +1,446 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.Data.SqlTypes;
+using System.Drawing;
+using System.IO;
+using System.Numerics;
+using System.Runtime.Serialization;
+using System.Text;
+using System.Xml;
+using System.Xml.Schema;
+using System.Xml.Serialization;
+using Xunit;
+using Xunit.Sdk;
+
+namespace System.Data.Tests
+{
+ // !! Important !!
+ // These tests manipulate global state, so they cannot be run in parallel with one another.
+ // We rely on xunit's default behavior of not parallelizing unit tests declared on the same
+ // test class: see https://xunit.net/docs/running-tests-in-parallel.html.
+ public class RestrictedTypeHandlingTests
+ {
+ private const string AppDomainDataSetDefaultAllowedTypesKey = "System.Data.DataSetDefaultAllowedTypes";
+
+ private static readonly Type[] _alwaysAllowedTypes = new Type[]
+ {
+ /* primitives */
+ typeof(bool),
+ typeof(char),
+ typeof(sbyte),
+ typeof(byte),
+ typeof(short),
+ typeof(ushort),
+ typeof(int),
+ typeof(uint),
+ typeof(long),
+ typeof(ulong),
+ typeof(float),
+ typeof(double),
+ typeof(decimal),
+ typeof(DateTime),
+ typeof(DateTimeOffset),
+ typeof(TimeSpan),
+ typeof(string),
+ typeof(Guid),
+ typeof(SqlBinary),
+ typeof(SqlBoolean),
+ typeof(SqlByte),
+ typeof(SqlBytes),
+ typeof(SqlChars),
+ typeof(SqlDateTime),
+ typeof(SqlDecimal),
+ typeof(SqlDouble),
+ typeof(SqlGuid),
+ typeof(SqlInt16),
+ typeof(SqlInt32),
+ typeof(SqlInt64),
+ typeof(SqlMoney),
+ typeof(SqlSingle),
+ typeof(SqlString),
+
+ /* non-primitives, but common */
+ typeof(object),
+ typeof(Type),
+ typeof(BigInteger),
+ typeof(Uri),
+
+ /* frequently used System.Drawing types */
+ typeof(Color),
+ typeof(Point),
+ typeof(PointF),
+ typeof(Rectangle),
+ typeof(RectangleF),
+ typeof(Size),
+ typeof(SizeF),
+
+ /* to test that enums are allowed */
+ typeof(StringComparison),
+ };
+
+ public static IEnumerable<object[]> AllowedTypes()
+ {
+ foreach (Type type in _alwaysAllowedTypes)
+ {
+ yield return new object[] { type }; // T
+ yield return new object[] { type.MakeArrayType() }; // T[] (SZArray)
+ yield return new object[] { type.MakeArrayType().MakeArrayType() }; // T[][] (jagged array)
+ yield return new object[] { typeof(List<>).MakeGenericType(type) }; // List<T>
+ }
+ }
+
+ public static IEnumerable<object[]> ForbiddenTypes()
+ {
+ // StringBuilder isn't in the allow list
+
+ yield return new object[] { typeof(StringBuilder) };
+ yield return new object[] { typeof(StringBuilder[]) };
+
+ // multi-dim arrays and non-sz arrays are forbidden
+
+ yield return new object[] { typeof(int[,]) };
+ yield return new object[] { Array.CreateInstance(typeof(int), new[] { 1 }, new[] { 1 }).GetType() };
+
+ // HashSet<T> isn't in the allow list
+
+ yield return new object[] { typeof(HashSet<int>) };
+
+ // DataSet / DataTable / SqlXml aren't in the allow list
+
+ yield return new object[] { typeof(DataSet) };
+ yield return new object[] { typeof(DataTable) };
+ yield return new object[] { typeof(SqlXml) };
+
+ // Enum, Array, and other base types aren't allowed
+
+ yield return new object[] { typeof(Enum) };
+ yield return new object[] { typeof(Array) };
+ yield return new object[] { typeof(ValueType) };
+ yield return new object[] { typeof(void) };
+ }
+
+ [Theory]
+ [MemberData(nameof(AllowedTypes))]
+ public void DataTable_ReadXml_AllowsKnownTypes(Type type)
+ {
+ // Arrange
+
+ DataTable table = new DataTable("MyTable");
+ table.Columns.Add("MyColumn", type);
+
+ string asXml = WriteXmlWithSchema(table.WriteXml);
+
+ // Act
+
+ table = ReadXml<DataTable>(asXml);
+
+ // Assert
+
+ Assert.Equal("MyTable", table.TableName);
+ Assert.Equal(1, table.Columns.Count);
+ Assert.Equal("MyColumn", table.Columns[0].ColumnName);
+ Assert.Equal(type, table.Columns[0].DataType);
+ }
+
+ [Theory]
+ [MemberData(nameof(ForbiddenTypes))]
+ public void DataTable_ReadXml_ForbidsUnknownTypes(Type type)
+ {
+ // Arrange
+
+ DataTable table = new DataTable("MyTable");
+ table.Columns.Add("MyColumn", type);
+
+ string asXml = WriteXmlWithSchema(table.WriteXml);
+
+ // Act & assert
+
+ Assert.Throws<InvalidOperationException>(() => ReadXml<DataTable>(asXml));
+ }
+
+ [Fact]
+ public void DataTable_ReadXml_HandlesXmlSerializableTypes()
+ {
+ // Arrange
+
+ DataTable table = new DataTable("MyTable");
+ table.Columns.Add("MyColumn", typeof(object));
+ table.Rows.Add(new MyXmlSerializableClass());
+
+ string asXml = WriteXmlWithSchema(table.WriteXml, XmlWriteMode.IgnoreSchema);
+
+ // Act & assert
+ // MyXmlSerializableClass shouldn't be allowed as a member for a column
+ // typed as 'object'.
+
+ table.Rows.Clear();
+ Assert.Throws<InvalidOperationException>(() => table.ReadXml(new StringReader(asXml)));
+ }
+
+ [Theory]
+ [MemberData(nameof(ForbiddenTypes))]
+ public void DataTable_ReadXmlSchema_AllowsUnknownTypes(Type type)
+ {
+ // Arrange
+
+ DataTable table = new DataTable("MyTable");
+ table.Columns.Add("MyColumn", type);
+
+ string asXml = WriteXmlWithSchema(table.WriteXml);
+
+ // Act
+
+ table = new DataTable();
+ table.ReadXmlSchema(new StringReader(asXml));
+
+ // Assert
+
+ Assert.Equal("MyTable", table.TableName);
+ Assert.Equal(1, table.Columns.Count);
+ Assert.Equal("MyColumn", table.Columns[0].ColumnName);
+ Assert.Equal(type, table.Columns[0].DataType);
+ }
+
+ [Fact]
+ public void DataTable_HonorsGloballyDefinedAllowList()
+ {
+ // Arrange
+
+ DataTable table = new DataTable("MyTable");
+ table.Columns.Add("MyColumn", typeof(MyCustomClass));
+
+ string asXml = WriteXmlWithSchema(table.WriteXml);
+
+ // Act & assert 1
+ // First call should fail since MyCustomClass not allowed
+
+ Assert.Throws<InvalidOperationException>(() => ReadXml<DataTable>(asXml));
+
+ // Act & assert 2
+ // Deserialization should succeed since it's now in the allow list
+
+ try
+ {
+ AppDomain.CurrentDomain.SetData(AppDomainDataSetDefaultAllowedTypesKey, new Type[]
+ {
+ typeof(MyCustomClass)
+ });
+
+ table = ReadXml<DataTable>(asXml);
+
+ Assert.Equal("MyTable", table.TableName);
+ Assert.Equal(1, table.Columns.Count);
+ Assert.Equal("MyColumn", table.Columns[0].ColumnName);
+ Assert.Equal(typeof(MyCustomClass), table.Columns[0].DataType);
+ }
+ finally
+ {
+ AppDomain.CurrentDomain.SetData(AppDomainDataSetDefaultAllowedTypesKey, null);
+ }
+ }
+
+ [Fact]
+ public void DataColumn_ConvertExpression_SubjectToAllowList_Success()
+ {
+ // Arrange
+
+ DataTable table = new DataTable("MyTable");
+ table.Columns.Add("MyColumn", typeof(object), "CONVERT('42', 'System.Int32')");
+
+ string asXml = WriteXmlWithSchema(table.WriteXml);
+
+ // Act
+
+ table = ReadXml<DataTable>(asXml);
+
+ // Assert
+
+ Assert.Equal("MyTable", table.TableName);
+ Assert.Equal(1, table.Columns.Count);
+ Assert.Equal("MyColumn", table.Columns[0].ColumnName);
+ Assert.Equal(typeof(object), table.Columns[0].DataType);
+ Assert.Equal("CONVERT('42', 'System.Int32')", table.Columns[0].Expression);
+ }
+
+ [Fact]
+ public void DataColumn_ConvertExpression_SubjectToAllowList_Failure()
+ {
+ // Arrange
+
+ DataTable table = new DataTable("MyTable");
+ table.Columns.Add("ColumnA", typeof(object));
+ table.Columns.Add("ColumnB", typeof(object), "CONVERT(ColumnA, 'System.Text.StringBuilder')");
+
+ string asXml = WriteXmlWithSchema(table.WriteXml);
+
+ // Act
+ // 'StringBuilder' isn't in the allow list, but we're not yet hydrating the Type
+ // object so we won't check it just yet.
+
+ table = ReadXml<DataTable>(asXml);
+
+ // Assert - the CONVERT function node should have captured the active allow list
+ // at construction and should apply it now.
+
+ Assert.Throws<InvalidOperationException>(() => table.Rows.Add(new StringBuilder()));
+ }
+
+ [Theory]
+ [MemberData(nameof(AllowedTypes))]
+ public void DataSet_ReadXml_AllowsKnownTypes(Type type)
+ {
+ // Arrange
+
+ DataSet set = new DataSet("MySet");
+ DataTable table = new DataTable("MyTable");
+ table.Columns.Add("MyColumn", type);
+ set.Tables.Add(table);
+
+ string asXml = WriteXmlWithSchema(set.WriteXml);
+
+ // Act
+
+ table = null;
+ set = ReadXml<DataSet>(asXml);
+
+ // Assert
+
+ Assert.Equal("MySet", set.DataSetName);
+ Assert.Equal(1, set.Tables.Count);
+
+ table = set.Tables[0];
+ Assert.Equal("MyTable", table.TableName);
+ Assert.Equal(1, table.Columns.Count);
+ Assert.Equal("MyColumn", table.Columns[0].ColumnName);
+ Assert.Equal(type, table.Columns[0].DataType);
+ }
+
+ [Theory]
+ [MemberData(nameof(ForbiddenTypes))]
+ public void DataSet_ReadXml_ForbidsUnknownTypes(Type type)
+ {
+ // Arrange
+
+ DataSet set = new DataSet("MySet");
+ DataTable table = new DataTable("MyTable");
+ table.Columns.Add("MyColumn", type);
+ set.Tables.Add(table);
+
+ string asXml = WriteXmlWithSchema(set.WriteXml);
+
+ // Act & assert
+
+ Assert.Throws<InvalidOperationException>(() => ReadXml<DataSet>(asXml));
+ }
+
+ [Theory]
+ [MemberData(nameof(ForbiddenTypes))]
+ public void DataSet_ReadXmlSchema_AllowsUnknownTypes(Type type)
+ {
+ // Arrange
+
+ DataSet set = new DataSet("MySet");
+ DataTable table = new DataTable("MyTable");
+ table.Columns.Add("MyColumn", type);
+ set.Tables.Add(table);
+
+ string asXml = WriteXmlWithSchema(set.WriteXml);
+
+ // Act
+
+ set = new DataSet();
+ set.ReadXmlSchema(new StringReader(asXml));
+
+ // Assert
+
+ Assert.Equal("MySet", set.DataSetName);
+ Assert.Equal(1, set.Tables.Count);
+
+ table = set.Tables[0];
+ Assert.Equal("MyTable", table.TableName);
+ Assert.Equal(1, table.Columns.Count);
+ Assert.Equal("MyColumn", table.Columns[0].ColumnName);
+ Assert.Equal(type, table.Columns[0].DataType);
+ }
+
+ [Fact]
+ public void SerializationGuard_BlocksFileAccessOnDeserialize()
+ {
+ // Arrange
+
+ DataTable table = new DataTable("MyTable");
+ table.Columns.Add("MyColumn", typeof(MyCustomClassThatWritesToAFile));
+ table.Rows.Add(new MyCustomClassThatWritesToAFile());
+
+ string asXml = WriteXmlWithSchema(table.WriteXml);
+ table.Rows.Clear();
+
+ // Act & assert
+
+ Assert.Throws<SerializationException>(() => table.ReadXml(new StringReader(asXml)));
+ }
+
+ private static string WriteXmlWithSchema(Action<TextWriter, XmlWriteMode> writeMethod, XmlWriteMode xmlWriteMode = XmlWriteMode.WriteSchema)
+ {
+ StringWriter writer = new StringWriter();
+ writeMethod(writer, xmlWriteMode);
+ return writer.ToString();
+ }
+
+ private static T ReadXml<T>(string xml) where T : IXmlSerializable, new()
+ {
+ T newObj = new T();
+ newObj.ReadXml(new XmlTextReader(new StringReader(xml)) { XmlResolver = null }); // suppress DTDs, same as runtime code
+ return newObj;
+ }
+
+ private sealed class MyCustomClass
+ {
+ }
+
+ public sealed class MyXmlSerializableClass : IXmlSerializable
+ {
+ public XmlSchema GetSchema()
+ {
+ return null;
+ }
+
+ public void ReadXml(XmlReader reader)
+ {
+ return; // no-op
+ }
+
+ public void WriteXml(XmlWriter writer)
+ {
+ writer.WriteElementString("MyElement", "MyValue");
+ }
+ }
+
+ private sealed class MyCustomClassThatWritesToAFile : IXmlSerializable
+ {
+ public XmlSchema GetSchema()
+ {
+ return null;
+ }
+
+ public void ReadXml(XmlReader reader)
+ {
+ // This should be called within a Serialization Guard scope, so the file write
+ // should fail.
+
+ string tempPath = Path.GetTempFileName();
+ File.WriteAllText(tempPath, "This better not be written...");
+ File.Delete(tempPath);
+ throw new XunitException("Unreachable code (SerializationGuard should have kicked in)");
+ }
+
+ public void WriteXml(XmlWriter writer)
+ {
+ writer.WriteElementString("MyElement", "MyValue");
+ }
+ }
+ }
+}