diff options
author | Levi Broderick <GrabYourPitchforks@users.noreply.github.com> | 2020-07-15 09:12:10 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-07-15 09:12:10 +0300 |
commit | 53976d38b1bd6917b8fa4d1dd4f009728ece3adb (patch) | |
tree | d1ffc5e72f6e48969b5eccfaae9620a0fd10b030 | |
parent | fe9e53e14694e281818945c28ea7468053248fb7 (diff) |
[release/5.0-preview7] Disallow unrestricted polymorphic deserialization in DataSet (#39314)v5.0.0-preview.7.20364.11
Fixes CVE-2020-1147
https://portal.msrc.microsoft.com/en-us/security-guidance/advisory/CVE-2020-1147
See also https://go.microsoft.com/fwlink/?linkid=2132227.
12 files changed, 824 insertions, 13 deletions
diff --git a/src/libraries/System.Data.Common/src/Resources/Strings.resx b/src/libraries/System.Data.Common/src/Resources/Strings.resx index 4535a6944bf..671bf490bc1 100644 --- a/src/libraries/System.Data.Common/src/Resources/Strings.resx +++ b/src/libraries/System.Data.Common/src/Resources/Strings.resx @@ -165,6 +165,7 @@ <data name="Data_ArgumentOutOfRange" xml:space="preserve"><value>'{0}' argument is out of range.</value></data> <data name="Data_ArgumentNull" xml:space="preserve"><value>'{0}' argument cannot be null.</value></data> <data name="Data_ArgumentContainsNull" xml:space="preserve"><value>'{0}' argument contains null value.</value></data> + <data name="Data_TypeNotAllowed" xml:space="preserve"><value>Type '{0}' is not allowed here. See https://go.microsoft.com/fwlink/?linkid=2132227 for more details.</value></data> <data name="DataColumns_OutOfRange" xml:space="preserve"><value>Cannot find column {0}.</value></data> <data name="DataColumns_Add1" xml:space="preserve"><value>Column '{0}' already belongs to this DataTable.</value></data> <data name="DataColumns_Add2" xml:space="preserve"><value>Column '{0}' already belongs to another DataTable.</value></data> diff --git a/src/libraries/System.Data.Common/src/System.Data.Common.csproj b/src/libraries/System.Data.Common/src/System.Data.Common.csproj index 20091047d6b..7beb59ac3a8 100644 --- a/src/libraries/System.Data.Common/src/System.Data.Common.csproj +++ b/src/libraries/System.Data.Common/src/System.Data.Common.csproj @@ -2,7 +2,7 @@ <PropertyGroup> <AssemblyName>System.Data.Common</AssemblyName> <AllowUnsafeBlocks>true</AllowUnsafeBlocks> - <TargetFrameworks>$(NetCoreAppCurrent)</TargetFrameworks> + <TargetFrameworks>$(NetCoreAppCurrent)-Windows_NT;$(NetCoreAppCurrent)-Unix</TargetFrameworks> </PropertyGroup> <ItemGroup> <Compile Include="System.Data.Common.TypeForwards.cs" /> @@ -123,6 +123,10 @@ <Compile Include="System\Data\KeyRestrictionBehavior.cs" /> <Compile Include="System\Data\LinqDataView.cs" /> <Compile Include="System\Data\LoadOption.cs" /> + <Compile Include="System\Data\LocalAppContextSwitches.cs" /> + <Compile Include="$(CommonPath)System\LocalAppContextSwitches.Common.cs"> + <Link>Common\System\LocalAppContextSwitches.Common.cs</Link> + </Compile> <Compile Include="System\Data\MappingType.cs" /> <Compile Include="System\Data\MergeFailedEvent.cs" /> <Compile Include="System\Data\MergeFailedEventHandler.cs" /> @@ -156,6 +160,7 @@ <Compile Include="System\Data\StrongTypingException.cs" /> <Compile Include="System\Data\TypedTableBase.cs" /> <Compile Include="System\Data\TypedTableBaseExtensions.cs" /> + <Compile Include="System\Data\TypeLimiter.cs" /> <Compile Include="System\Data\UniqueConstraint.cs" /> <Compile Include="System\Data\UpdateRowSource.cs" /> <Compile Include="System\Data\Common\UInt64Storage.cs" /> @@ -295,25 +300,23 @@ <Compile Include="System\Data\ProviderBase\SchemaMapping.cs" /> </ItemGroup> <ItemGroup> - <Reference Include="System.Collections" /> + <ProjectReference Include="$(CoreLibProject)" /> + <ProjectReference Include="..\..\System.Collections\src\System.Collections.csproj" /> + <ProjectReference Include="..\..\System.Collections.NonGeneric\src\System.Collections.NonGeneric.csproj" /> + <ProjectReference Include="..\..\System.ComponentModel.TypeConverter\src\System.ComponentModel.TypeConverter.csproj" /> + <ProjectReference Include="..\..\System.Runtime\src\System.Runtime.csproj" /> + <ProjectReference Include="..\..\System.Runtime.Extensions\src\System.Runtime.Extensions.csproj" /> + <ProjectReference Include="..\..\System.Private.Uri\src\System.Private.Uri.csproj" /> <Reference Include="System.Collections.Concurrent" /> - <Reference Include="System.Collections.NonGeneric" /> <Reference Include="System.ComponentModel" /> <Reference Include="System.ComponentModel.Primitives" /> - <Reference Include="System.ComponentModel.TypeConverter" /> - <Reference Include="System.Diagnostics.Tracing" /> + <Reference Include="System.Drawing.Primitives" /> <Reference Include="System.Linq" /> <Reference Include="System.Linq.Expressions" /> - <Reference Include="System.Memory" /> <Reference Include="System.ObjectModel" /> - <Reference Include="System.Runtime" /> - <Reference Include="System.Runtime.Extensions" /> <Reference Include="System.Runtime.Numerics" /> <Reference Include="System.Runtime.Serialization.Formatters" /> - <Reference Include="System.Text.Encoding.Extensions" /> <Reference Include="System.Text.RegularExpressions" /> - <Reference Include="System.Threading" /> - <Reference Include="System.Threading.Thread" /> <Reference Include="System.Transactions.Local" /> <Reference Include="System.Xml.ReaderWriter" /> <Reference Include="System.Xml.XmlSerializer" /> diff --git a/src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs b/src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs index ffe4960cce7..dc3fe053356 100644 --- a/src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs +++ b/src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs @@ -406,6 +406,9 @@ namespace System.Data.Common if (type == typeof(object)) throw ExceptionBuilder.CanNotDeserializeObjectType(); + + TypeLimiter.EnsureTypeIsAllowed(type); + if (!isBaseCLRType) { retValue = System.Activator.CreateInstance(type, true); diff --git a/src/libraries/System.Data.Common/src/System/Data/DataColumn.cs b/src/libraries/System.Data.Common/src/System/Data/DataColumn.cs index d44ec1e0b88..560f6ecd7e7 100644 --- a/src/libraries/System.Data.Common/src/System/Data/DataColumn.cs +++ b/src/libraries/System.Data.Common/src/System/Data/DataColumn.cs @@ -143,6 +143,7 @@ namespace System.Data private void UpdateColumnType(Type type, StorageType typeCode) { + TypeLimiter.EnsureTypeIsAllowed(type); _dataType = type; _storageType = typeCode; if (StorageType.DateTime != typeCode) diff --git a/src/libraries/System.Data.Common/src/System/Data/DataException.cs b/src/libraries/System.Data.Common/src/System/Data/DataException.cs index 7aa84800747..aede6f282f1 100644 --- a/src/libraries/System.Data.Common/src/System/Data/DataException.cs +++ b/src/libraries/System.Data.Common/src/System/Data/DataException.cs @@ -350,6 +350,7 @@ namespace System.Data public static Exception ArgumentOutOfRange(string paramName) => _ArgumentOutOfRange(paramName, SR.Format(SR.Data_ArgumentOutOfRange, paramName)); public static Exception BadObjectPropertyAccess(string error) => _InvalidOperation(SR.Format(SR.DataConstraint_BadObjectPropertyAccess, error)); public static Exception ArgumentContainsNull(string paramName) => _Argument(paramName, SR.Format(SR.Data_ArgumentContainsNull, paramName)); + public static Exception TypeNotAllowed(Type type) => _InvalidOperation(SR.Format(SR.Data_TypeNotAllowed, type.AssemblyQualifiedName)); // diff --git a/src/libraries/System.Data.Common/src/System/Data/DataSet.cs b/src/libraries/System.Data.Common/src/System/Data/DataSet.cs index f4d6ad887d5..c09b35203a9 100644 --- a/src/libraries/System.Data.Common/src/System/Data/DataSet.cs +++ b/src/libraries/System.Data.Common/src/System/Data/DataSet.cs @@ -1961,9 +1961,11 @@ namespace System.Data internal XmlReadMode ReadXml(XmlReader reader, bool denyResolving) { + IDisposable restrictedScope = null; long logScopeId = DataCommonEventSource.Log.EnterScope("<ds.DataSet.ReadXml|INFO> {0}, denyResolving={1}", ObjectID, denyResolving); try { + restrictedScope = TypeLimiter.EnterRestrictedScope(this); DataTable.DSRowDiffIdUsageSection rowDiffIdUsage = default; try { @@ -2231,6 +2233,7 @@ namespace System.Data } finally { + restrictedScope?.Dispose(); DataCommonEventSource.Log.ExitScope(logScopeId); } } @@ -2467,9 +2470,11 @@ namespace System.Data internal XmlReadMode ReadXml(XmlReader reader, XmlReadMode mode, bool denyResolving) { + IDisposable restictedScope = null; long logScopeId = DataCommonEventSource.Log.EnterScope("<ds.DataSet.ReadXml|INFO> {0}, mode={1}, denyResolving={2}", ObjectID, mode, denyResolving); try { + restictedScope = TypeLimiter.EnterRestrictedScope(this); XmlReadMode ret = mode; if (reader == null) @@ -2711,6 +2716,7 @@ namespace System.Data } finally { + restictedScope?.Dispose(); DataCommonEventSource.Log.ExitScope(logScopeId); } } diff --git a/src/libraries/System.Data.Common/src/System/Data/DataTable.cs b/src/libraries/System.Data.Common/src/System/Data/DataTable.cs index 917908ec2a4..5b2fa268c94 100644 --- a/src/libraries/System.Data.Common/src/System/Data/DataTable.cs +++ b/src/libraries/System.Data.Common/src/System/Data/DataTable.cs @@ -5659,9 +5659,11 @@ namespace System.Data internal XmlReadMode ReadXml(XmlReader reader, bool denyResolving) { + IDisposable restrictedScope = null; long logScopeId = DataCommonEventSource.Log.EnterScope("<ds.DataTable.ReadXml|INFO> {0}, denyResolving={1}", ObjectID, denyResolving); try { + restrictedScope = TypeLimiter.EnterRestrictedScope(this); RowDiffIdUsageSection rowDiffIdUsage = default; try { @@ -5896,15 +5898,18 @@ namespace System.Data } finally { + restrictedScope?.Dispose(); DataCommonEventSource.Log.ExitScope(logScopeId); } } internal XmlReadMode ReadXml(XmlReader reader, XmlReadMode mode, bool denyResolving) { + IDisposable restrictedScope = null; RowDiffIdUsageSection rowDiffIdUsage = default; try { + restrictedScope = TypeLimiter.EnterRestrictedScope(this); bool fSchemaFound = false; bool fDataFound = false; bool fIsXdr = false; @@ -6190,6 +6195,7 @@ namespace System.Data } finally { + restrictedScope?.Dispose(); // prepare and cleanup rowDiffId hashtable rowDiffIdUsage.Cleanup(); } diff --git a/src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs b/src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs index 33aa941e303..472cd9be2b7 100644 --- a/src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs +++ b/src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Data.Common; using System.Data.SqlTypes; using System.Diagnostics; +using System.Runtime.Serialization; namespace System.Data { @@ -16,6 +17,7 @@ namespace System.Data internal int _argumentCount = 0; internal const int initialCapacity = 1; internal ExpressionNode[] _arguments; + private readonly TypeLimiter _capturedLimiter = null; private static readonly Function[] s_funcs = new Function[] { new Function("Abs", FunctionId.Abs, typeof(object), true, false, 1, typeof(object), null, null), @@ -40,6 +42,12 @@ namespace System.Data internal FunctionNode(DataTable table, string name) : base(table) { + // Because FunctionNode instances are created eagerly but evaluated lazily, + // we need to capture the deserialization scope here. The scope could be + // null if no deserialization is in progress. + + _capturedLimiter = TypeLimiter.Capture(); + _name = name; for (int i = 0; i < s_funcs.Length; i++) { @@ -289,6 +297,11 @@ namespace System.Data throw ExprException.InvalidType(typeName); } + // ReadXml might not be on the current call stack. So we'll use the TypeLimiter + // that was captured when this FunctionNode instance was created. + + TypeLimiter.EnsureTypeIsAllowed(dataType, _capturedLimiter); + return dataType; } @@ -494,10 +507,17 @@ namespace System.Data { return SqlConvert.ChangeType2((decimal)SqlConvert.ChangeType2(argumentValues[0], StorageType.Decimal, typeof(decimal), FormatProvider), mytype, type, FormatProvider); } - return SqlConvert.ChangeType2(argumentValues[0], mytype, type, FormatProvider); } - return SqlConvert.ChangeType2(argumentValues[0], mytype, type, FormatProvider); + // The Convert function can be called lazily, outside of a previous Serialization Guard scope. + // If there was a type limiter scope on the stack at the time this Convert function was created, + // we must manually re-enter the Serialization Guard scope. + + DeserializationToken deserializationToken = (_capturedLimiter != null) ? SerializationInfo.StartDeserialization() : default; + using (deserializationToken) + { + return SqlConvert.ChangeType2(argumentValues[0], mytype, type, FormatProvider); + } } return argumentValues[0]; diff --git a/src/libraries/System.Data.Common/src/System/Data/LocalAppContextSwitches.cs b/src/libraries/System.Data.Common/src/System/Data/LocalAppContextSwitches.cs new file mode 100644 index 00000000000..42afa2b8fed --- /dev/null +++ b/src/libraries/System.Data.Common/src/System/Data/LocalAppContextSwitches.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.CompilerServices; + +namespace System +{ + internal static partial class LocalAppContextSwitches + { + private static int s_allowArbitraryTypeInstantiation; + public static bool AllowArbitraryTypeInstantiation + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => GetCachedSwitchValue("Switch.System.Data.AllowArbitraryDataSetTypeInstantiation", ref s_allowArbitraryTypeInstantiation); + } + } +} diff --git a/src/libraries/System.Data.Common/src/System/Data/TypeLimiter.cs b/src/libraries/System.Data.Common/src/System/Data/TypeLimiter.cs new file mode 100644 index 00000000000..1ff77cb9952 --- /dev/null +++ b/src/libraries/System.Data.Common/src/System/Data/TypeLimiter.cs @@ -0,0 +1,305 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Diagnostics; +using System.Drawing; +using System.Linq; +using System.Numerics; +using System.Runtime.Serialization; + +namespace System.Data +{ + internal sealed class TypeLimiter + { + [ThreadStatic] + private static Scope s_activeScope; + + private Scope m_instanceScope; + + private const string AppDomainDataSetDefaultAllowedTypesKey = "System.Data.DataSetDefaultAllowedTypes"; + + private TypeLimiter(Scope scope) + { + Debug.Assert(scope != null); + m_instanceScope = scope; + } + + private static bool IsTypeLimitingDisabled + => LocalAppContextSwitches.AllowArbitraryTypeInstantiation; + + /// <summary> + /// Captures the current <see cref="TypeLimiter"/> instance so that future + /// type checks can be performed against the allow list that was active during + /// the current deserialization scope. + /// </summary> + /// <remarks> + /// Returns null if no limiter is active. + /// </remarks> + public static TypeLimiter Capture() + { + Scope activeScope = s_activeScope; + return (activeScope != null) ? new TypeLimiter(activeScope) : null; + } + + /// <summary> + /// Ensures the requested type is allowed by the rules of the active + /// deserialization scope. If a captured scope is provided, we'll use + /// that previously captured scope rather than the thread-static active + /// scope. + /// </summary> + /// <exception cref="InvalidOperationException"> + /// If <paramref name="type"/> is not allowed. + /// </exception> + public static void EnsureTypeIsAllowed(Type type, TypeLimiter capturedLimiter = null) + { + if (type is null) + { + return; // nothing to check + } + + Scope capturedScope = capturedLimiter?.m_instanceScope ?? s_activeScope; + if (capturedScope is null) + { + return; // we're not in a restricted scope + } + + if (capturedScope.IsAllowedType(type)) + { + return; // type was explicitly allowed + } + + // We encountered a type that wasn't in the allow list. + // Throw an exception to fail the current operation. + + throw ExceptionBuilder.TypeNotAllowed(type); + } + + public static IDisposable EnterRestrictedScope(DataSet dataSet) + { + if (IsTypeLimitingDisabled) + { + return null; // protections aren't enabled + } + + Scope newScope = new Scope(s_activeScope, GetPreviouslyDeclaredDataTypes(dataSet)); + s_activeScope = newScope; + return newScope; + } + + public static IDisposable EnterRestrictedScope(DataTable dataTable) + { + if (IsTypeLimitingDisabled) + { + return null; // protections aren't enabled + } + + Scope newScope = new Scope(s_activeScope, GetPreviouslyDeclaredDataTypes(dataTable)); + s_activeScope = newScope; + return newScope; + } + + /// <summary> + /// Given a <see cref="DataTable"/>, returns all of the <see cref="DataColumn.DataType"/> + /// values declared on the instance. + /// </summary> + private static IEnumerable<Type> GetPreviouslyDeclaredDataTypes(DataTable dataTable) + { + return (dataTable != null) + ? dataTable.Columns.Cast<DataColumn>().Select(column => column.DataType) + : Enumerable.Empty<Type>(); + } + + /// <summary> + /// Given a <see cref="DataSet"/>, returns all of the <see cref="DataColumn.DataType"/> + /// values declared on the instance. + /// </summary> + private static IEnumerable<Type> GetPreviouslyDeclaredDataTypes(DataSet dataSet) + { + return (dataSet != null) + ? dataSet.Tables.Cast<DataTable>().SelectMany(table => GetPreviouslyDeclaredDataTypes(table)) + : Enumerable.Empty<Type>(); + } + + private sealed class Scope : IDisposable + { + /// <summary> + /// Types which are always allowed, unconditionally. + /// </summary> + private static readonly HashSet<Type> s_allowedTypes = new HashSet<Type>() + { + /* primitives */ + typeof(bool), + typeof(char), + typeof(sbyte), + typeof(byte), + typeof(short), + typeof(ushort), + typeof(int), + typeof(uint), + typeof(long), + typeof(ulong), + typeof(float), + typeof(double), + typeof(decimal), + typeof(DateTime), + typeof(DateTimeOffset), + typeof(TimeSpan), + typeof(string), + typeof(Guid), + typeof(SqlBinary), + typeof(SqlBoolean), + typeof(SqlByte), + typeof(SqlBytes), + typeof(SqlChars), + typeof(SqlDateTime), + typeof(SqlDecimal), + typeof(SqlDouble), + typeof(SqlGuid), + typeof(SqlInt16), + typeof(SqlInt32), + typeof(SqlInt64), + typeof(SqlMoney), + typeof(SqlSingle), + typeof(SqlString), + + /* non-primitives, but common */ + typeof(object), + typeof(Type), + typeof(BigInteger), + typeof(Uri), + + /* frequently used System.Drawing types */ + typeof(Color), + typeof(Point), + typeof(PointF), + typeof(Rectangle), + typeof(RectangleF), + typeof(Size), + typeof(SizeF), + }; + + /// <summary> + /// Types which are allowed within the context of this scope. + /// </summary> + private HashSet<Type> m_allowedTypes; + + /// <summary> + /// This thread's previous scope. + /// </summary> + private readonly Scope m_previousScope; + + /// <summary> + /// The Serialization Guard token associated with this scope. + /// </summary> + private readonly DeserializationToken m_deserializationToken; + + internal Scope(Scope previousScope, IEnumerable<Type> allowedTypes) + { + Debug.Assert(allowedTypes != null); + + m_previousScope = previousScope; + m_allowedTypes = new HashSet<Type>(allowedTypes.Where(type => type != null)); + m_deserializationToken = SerializationInfo.StartDeserialization(); + } + + public void Dispose() + { + if (this != s_activeScope) + { + // Stacks should never be popped out of order. + // We want to trap this condition in production. + Debug.Fail("Scope was popped out of order."); + throw new ObjectDisposedException(GetType().FullName); + } + + m_deserializationToken.Dispose(); // it's a readonly struct, but Dispose still works properly + s_activeScope = m_previousScope; // could be null + } + + public bool IsAllowedType(Type type) + { + Debug.Assert(type != null); + + // Is the incoming type unconditionally allowed? + + if (IsTypeUnconditionallyAllowed(type)) + { + return true; + } + + // The incoming type is allowed if the current scope or any nested inner + // scope allowed it. + + for (Scope currentScope = this; currentScope != null; currentScope = currentScope.m_previousScope) + { + if (currentScope.m_allowedTypes.Contains(type)) + { + return true; + } + } + + // Did the application programmatically allow this type to be deserialized? + + Type[] appDomainAllowedTypes = (Type[])AppDomain.CurrentDomain.GetData(AppDomainDataSetDefaultAllowedTypesKey); + if (appDomainAllowedTypes != null) + { + for (int i = 0; i < appDomainAllowedTypes.Length; i++) + { + if (type == appDomainAllowedTypes[i]) + { + return true; + } + } + } + + // All checks failed + + return false; + } + + private static bool IsTypeUnconditionallyAllowed(Type type) + { + TryAgain: + Debug.Assert(type != null); + + // Check the list of unconditionally allowed types. + + if (s_allowedTypes.Contains(type)) + { + return true; + } + + // Enums are also always allowed, as we optimistically assume the app + // developer didn't define a dangerous enum type. + + if (type.IsEnum) + { + return true; + } + + // Allow single-dimensional arrays of any unconditionally allowed type. + + if (type.IsSZArray) + { + type = type.GetElementType(); + goto TryAgain; + } + + // Allow generic lists of any unconditionally allowed type. + + if (type.IsGenericType && !type.IsGenericTypeDefinition && type.GetGenericTypeDefinition() == typeof(List<>)) + { + type = type.GetGenericArguments()[0]; + goto TryAgain; + } + + // All checks failed. + + return false; + } + } + } +} diff --git a/src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj b/src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj index 4d0933db75b..9742209ccf9 100644 --- a/src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj +++ b/src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj @@ -112,6 +112,7 @@ <Compile Include="System\Data\VersionNotFoundException.cs" /> <Compile Include="System\Data\XmlDataLoaderTest.cs" /> <Compile Include="System\Data\XmlDataReaderTest.cs" /> + <Compile Include="System\Data\RestrictedTypeHandlingTests.cs" /> <Compile Include="System\Xml\XmlDataDocumentTests.cs" /> <Compile Include="$(CommonTestPath)System\Diagnostics\Tracing\TestEventListener.cs" Link="Common\System\Diagnostics\Tracing\TestEventListener.cs" /> diff --git a/src/libraries/System.Data.Common/tests/System/Data/RestrictedTypeHandlingTests.cs b/src/libraries/System.Data.Common/tests/System/Data/RestrictedTypeHandlingTests.cs new file mode 100644 index 00000000000..3f76cebc805 --- /dev/null +++ b/src/libraries/System.Data.Common/tests/System/Data/RestrictedTypeHandlingTests.cs @@ -0,0 +1,446 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Drawing; +using System.IO; +using System.Numerics; +using System.Runtime.Serialization; +using System.Text; +using System.Xml; +using System.Xml.Schema; +using System.Xml.Serialization; +using Xunit; +using Xunit.Sdk; + +namespace System.Data.Tests +{ + // !! Important !! + // These tests manipulate global state, so they cannot be run in parallel with one another. + // We rely on xunit's default behavior of not parallelizing unit tests declared on the same + // test class: see https://xunit.net/docs/running-tests-in-parallel.html. + public class RestrictedTypeHandlingTests + { + private const string AppDomainDataSetDefaultAllowedTypesKey = "System.Data.DataSetDefaultAllowedTypes"; + + private static readonly Type[] _alwaysAllowedTypes = new Type[] + { + /* primitives */ + typeof(bool), + typeof(char), + typeof(sbyte), + typeof(byte), + typeof(short), + typeof(ushort), + typeof(int), + typeof(uint), + typeof(long), + typeof(ulong), + typeof(float), + typeof(double), + typeof(decimal), + typeof(DateTime), + typeof(DateTimeOffset), + typeof(TimeSpan), + typeof(string), + typeof(Guid), + typeof(SqlBinary), + typeof(SqlBoolean), + typeof(SqlByte), + typeof(SqlBytes), + typeof(SqlChars), + typeof(SqlDateTime), + typeof(SqlDecimal), + typeof(SqlDouble), + typeof(SqlGuid), + typeof(SqlInt16), + typeof(SqlInt32), + typeof(SqlInt64), + typeof(SqlMoney), + typeof(SqlSingle), + typeof(SqlString), + + /* non-primitives, but common */ + typeof(object), + typeof(Type), + typeof(BigInteger), + typeof(Uri), + + /* frequently used System.Drawing types */ + typeof(Color), + typeof(Point), + typeof(PointF), + typeof(Rectangle), + typeof(RectangleF), + typeof(Size), + typeof(SizeF), + + /* to test that enums are allowed */ + typeof(StringComparison), + }; + + public static IEnumerable<object[]> AllowedTypes() + { + foreach (Type type in _alwaysAllowedTypes) + { + yield return new object[] { type }; // T + yield return new object[] { type.MakeArrayType() }; // T[] (SZArray) + yield return new object[] { type.MakeArrayType().MakeArrayType() }; // T[][] (jagged array) + yield return new object[] { typeof(List<>).MakeGenericType(type) }; // List<T> + } + } + + public static IEnumerable<object[]> ForbiddenTypes() + { + // StringBuilder isn't in the allow list + + yield return new object[] { typeof(StringBuilder) }; + yield return new object[] { typeof(StringBuilder[]) }; + + // multi-dim arrays and non-sz arrays are forbidden + + yield return new object[] { typeof(int[,]) }; + yield return new object[] { Array.CreateInstance(typeof(int), new[] { 1 }, new[] { 1 }).GetType() }; + + // HashSet<T> isn't in the allow list + + yield return new object[] { typeof(HashSet<int>) }; + + // DataSet / DataTable / SqlXml aren't in the allow list + + yield return new object[] { typeof(DataSet) }; + yield return new object[] { typeof(DataTable) }; + yield return new object[] { typeof(SqlXml) }; + + // Enum, Array, and other base types aren't allowed + + yield return new object[] { typeof(Enum) }; + yield return new object[] { typeof(Array) }; + yield return new object[] { typeof(ValueType) }; + yield return new object[] { typeof(void) }; + } + + [Theory] + [MemberData(nameof(AllowedTypes))] + public void DataTable_ReadXml_AllowsKnownTypes(Type type) + { + // Arrange + + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", type); + + string asXml = WriteXmlWithSchema(table.WriteXml); + + // Act + + table = ReadXml<DataTable>(asXml); + + // Assert + + Assert.Equal("MyTable", table.TableName); + Assert.Equal(1, table.Columns.Count); + Assert.Equal("MyColumn", table.Columns[0].ColumnName); + Assert.Equal(type, table.Columns[0].DataType); + } + + [Theory] + [MemberData(nameof(ForbiddenTypes))] + public void DataTable_ReadXml_ForbidsUnknownTypes(Type type) + { + // Arrange + + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", type); + + string asXml = WriteXmlWithSchema(table.WriteXml); + + // Act & assert + + Assert.Throws<InvalidOperationException>(() => ReadXml<DataTable>(asXml)); + } + + [Fact] + public void DataTable_ReadXml_HandlesXmlSerializableTypes() + { + // Arrange + + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", typeof(object)); + table.Rows.Add(new MyXmlSerializableClass()); + + string asXml = WriteXmlWithSchema(table.WriteXml, XmlWriteMode.IgnoreSchema); + + // Act & assert + // MyXmlSerializableClass shouldn't be allowed as a member for a column + // typed as 'object'. + + table.Rows.Clear(); + Assert.Throws<InvalidOperationException>(() => table.ReadXml(new StringReader(asXml))); + } + + [Theory] + [MemberData(nameof(ForbiddenTypes))] + public void DataTable_ReadXmlSchema_AllowsUnknownTypes(Type type) + { + // Arrange + + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", type); + + string asXml = WriteXmlWithSchema(table.WriteXml); + + // Act + + table = new DataTable(); + table.ReadXmlSchema(new StringReader(asXml)); + + // Assert + + Assert.Equal("MyTable", table.TableName); + Assert.Equal(1, table.Columns.Count); + Assert.Equal("MyColumn", table.Columns[0].ColumnName); + Assert.Equal(type, table.Columns[0].DataType); + } + + [Fact] + public void DataTable_HonorsGloballyDefinedAllowList() + { + // Arrange + + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", typeof(MyCustomClass)); + + string asXml = WriteXmlWithSchema(table.WriteXml); + + // Act & assert 1 + // First call should fail since MyCustomClass not allowed + + Assert.Throws<InvalidOperationException>(() => ReadXml<DataTable>(asXml)); + + // Act & assert 2 + // Deserialization should succeed since it's now in the allow list + + try + { + AppDomain.CurrentDomain.SetData(AppDomainDataSetDefaultAllowedTypesKey, new Type[] + { + typeof(MyCustomClass) + }); + + table = ReadXml<DataTable>(asXml); + + Assert.Equal("MyTable", table.TableName); + Assert.Equal(1, table.Columns.Count); + Assert.Equal("MyColumn", table.Columns[0].ColumnName); + Assert.Equal(typeof(MyCustomClass), table.Columns[0].DataType); + } + finally + { + AppDomain.CurrentDomain.SetData(AppDomainDataSetDefaultAllowedTypesKey, null); + } + } + + [Fact] + public void DataColumn_ConvertExpression_SubjectToAllowList_Success() + { + // Arrange + + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", typeof(object), "CONVERT('42', 'System.Int32')"); + + string asXml = WriteXmlWithSchema(table.WriteXml); + + // Act + + table = ReadXml<DataTable>(asXml); + + // Assert + + Assert.Equal("MyTable", table.TableName); + Assert.Equal(1, table.Columns.Count); + Assert.Equal("MyColumn", table.Columns[0].ColumnName); + Assert.Equal(typeof(object), table.Columns[0].DataType); + Assert.Equal("CONVERT('42', 'System.Int32')", table.Columns[0].Expression); + } + + [Fact] + public void DataColumn_ConvertExpression_SubjectToAllowList_Failure() + { + // Arrange + + DataTable table = new DataTable("MyTable"); + table.Columns.Add("ColumnA", typeof(object)); + table.Columns.Add("ColumnB", typeof(object), "CONVERT(ColumnA, 'System.Text.StringBuilder')"); + + string asXml = WriteXmlWithSchema(table.WriteXml); + + // Act + // 'StringBuilder' isn't in the allow list, but we're not yet hydrating the Type + // object so we won't check it just yet. + + table = ReadXml<DataTable>(asXml); + + // Assert - the CONVERT function node should have captured the active allow list + // at construction and should apply it now. + + Assert.Throws<InvalidOperationException>(() => table.Rows.Add(new StringBuilder())); + } + + [Theory] + [MemberData(nameof(AllowedTypes))] + public void DataSet_ReadXml_AllowsKnownTypes(Type type) + { + // Arrange + + DataSet set = new DataSet("MySet"); + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", type); + set.Tables.Add(table); + + string asXml = WriteXmlWithSchema(set.WriteXml); + + // Act + + table = null; + set = ReadXml<DataSet>(asXml); + + // Assert + + Assert.Equal("MySet", set.DataSetName); + Assert.Equal(1, set.Tables.Count); + + table = set.Tables[0]; + Assert.Equal("MyTable", table.TableName); + Assert.Equal(1, table.Columns.Count); + Assert.Equal("MyColumn", table.Columns[0].ColumnName); + Assert.Equal(type, table.Columns[0].DataType); + } + + [Theory] + [MemberData(nameof(ForbiddenTypes))] + public void DataSet_ReadXml_ForbidsUnknownTypes(Type type) + { + // Arrange + + DataSet set = new DataSet("MySet"); + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", type); + set.Tables.Add(table); + + string asXml = WriteXmlWithSchema(set.WriteXml); + + // Act & assert + + Assert.Throws<InvalidOperationException>(() => ReadXml<DataSet>(asXml)); + } + + [Theory] + [MemberData(nameof(ForbiddenTypes))] + public void DataSet_ReadXmlSchema_AllowsUnknownTypes(Type type) + { + // Arrange + + DataSet set = new DataSet("MySet"); + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", type); + set.Tables.Add(table); + + string asXml = WriteXmlWithSchema(set.WriteXml); + + // Act + + set = new DataSet(); + set.ReadXmlSchema(new StringReader(asXml)); + + // Assert + + Assert.Equal("MySet", set.DataSetName); + Assert.Equal(1, set.Tables.Count); + + table = set.Tables[0]; + Assert.Equal("MyTable", table.TableName); + Assert.Equal(1, table.Columns.Count); + Assert.Equal("MyColumn", table.Columns[0].ColumnName); + Assert.Equal(type, table.Columns[0].DataType); + } + + [Fact] + public void SerializationGuard_BlocksFileAccessOnDeserialize() + { + // Arrange + + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", typeof(MyCustomClassThatWritesToAFile)); + table.Rows.Add(new MyCustomClassThatWritesToAFile()); + + string asXml = WriteXmlWithSchema(table.WriteXml); + table.Rows.Clear(); + + // Act & assert + + Assert.Throws<SerializationException>(() => table.ReadXml(new StringReader(asXml))); + } + + private static string WriteXmlWithSchema(Action<TextWriter, XmlWriteMode> writeMethod, XmlWriteMode xmlWriteMode = XmlWriteMode.WriteSchema) + { + StringWriter writer = new StringWriter(); + writeMethod(writer, xmlWriteMode); + return writer.ToString(); + } + + private static T ReadXml<T>(string xml) where T : IXmlSerializable, new() + { + T newObj = new T(); + newObj.ReadXml(new XmlTextReader(new StringReader(xml)) { XmlResolver = null }); // suppress DTDs, same as runtime code + return newObj; + } + + private sealed class MyCustomClass + { + } + + public sealed class MyXmlSerializableClass : IXmlSerializable + { + public XmlSchema GetSchema() + { + return null; + } + + public void ReadXml(XmlReader reader) + { + return; // no-op + } + + public void WriteXml(XmlWriter writer) + { + writer.WriteElementString("MyElement", "MyValue"); + } + } + + private sealed class MyCustomClassThatWritesToAFile : IXmlSerializable + { + public XmlSchema GetSchema() + { + return null; + } + + public void ReadXml(XmlReader reader) + { + // This should be called within a Serialization Guard scope, so the file write + // should fail. + + string tempPath = Path.GetTempFileName(); + File.WriteAllText(tempPath, "This better not be written..."); + File.Delete(tempPath); + throw new XunitException("Unreachable code (SerializationGuard should have kicked in)"); + } + + public void WriteXml(XmlWriter writer) + { + writer.WriteElementString("MyElement", "MyValue"); + } + } + } +} |