diff options
author | msftbot[bot] <48340428+msftbot[bot]@users.noreply.github.com> | 2022-09-15 22:12:59 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-15 22:12:59 +0300 |
commit | dbb71daaaef12c231607413102453998141544a1 (patch) | |
tree | 80afa4687fdd50eaab15d37f60859bdfffa3dcae | |
parent | 1ea90cc319c273563b575d73287cd324fd368ad9 (diff) | |
parent | 30a6a397a0692ff76779d05e35021f5b91bd86b2 (diff) |
Merge pull request #43997 from dotnet-maestro-bot/merge/release/7.0-to-main
[automated] Merge branch 'release/7.0' => 'main'
6 files changed, 306 insertions, 11 deletions
diff --git a/src/Mvc/Mvc.Abstractions/src/ModelBinding/ModelStateDictionary.cs b/src/Mvc/Mvc.Abstractions/src/ModelBinding/ModelStateDictionary.cs index 2be3c223ff..a7082e2d5c 100644 --- a/src/Mvc/Mvc.Abstractions/src/ModelBinding/ModelStateDictionary.cs +++ b/src/Mvc/Mvc.Abstractions/src/ModelBinding/ModelStateDictionary.cs @@ -23,6 +23,9 @@ public class ModelStateDictionary : IReadOnlyDictionary<string, ModelStateEntry? /// </summary> public static readonly int DefaultMaxAllowedErrors = 200; + // internal for testing + internal const int DefaultMaxRecursionDepth = 32; + private const char DelimiterDot = '.'; private const char DelimiterOpen = '['; @@ -41,8 +44,18 @@ public class ModelStateDictionary : IReadOnlyDictionary<string, ModelStateEntry? /// Initializes a new instance of the <see cref="ModelStateDictionary"/> class. /// </summary> public ModelStateDictionary(int maxAllowedErrors) + : this(maxAllowedErrors, maxValidationDepth: DefaultMaxRecursionDepth, maxStateDepth: DefaultMaxRecursionDepth) + { + } + + /// <summary> + /// Initializes a new instance of the <see cref="ModelStateDictionary"/> class. + /// </summary> + private ModelStateDictionary(int maxAllowedErrors, int maxValidationDepth, int maxStateDepth) { MaxAllowedErrors = maxAllowedErrors; + MaxValidationDepth = maxValidationDepth; + MaxStateDepth = maxStateDepth; var emptySegment = new StringSegment(buffer: string.Empty); _root = new ModelStateNode(subKey: emptySegment) { @@ -56,7 +69,9 @@ public class ModelStateDictionary : IReadOnlyDictionary<string, ModelStateEntry? /// </summary> /// <param name="dictionary">The <see cref="ModelStateDictionary"/> to copy values from.</param> public ModelStateDictionary(ModelStateDictionary dictionary) - : this(dictionary?.MaxAllowedErrors ?? DefaultMaxAllowedErrors) + : this(dictionary?.MaxAllowedErrors ?? DefaultMaxAllowedErrors, + dictionary?.MaxValidationDepth ?? DefaultMaxRecursionDepth, + dictionary?.MaxStateDepth ?? DefaultMaxRecursionDepth) { if (dictionary == null) { @@ -152,7 +167,7 @@ public class ModelStateDictionary : IReadOnlyDictionary<string, ModelStateEntry? } /// <inheritdoc /> - public ModelValidationState ValidationState => GetValidity(_root) ?? ModelValidationState.Valid; + public ModelValidationState ValidationState => GetValidity(_root, currentDepth: 0) ?? ModelValidationState.Valid; /// <inheritdoc /> public ModelStateEntry? this[string key] @@ -172,6 +187,10 @@ public class ModelStateDictionary : IReadOnlyDictionary<string, ModelStateEntry? // Flag that indicates if TooManyModelErrorException has already been added to this dictionary. private bool HasRecordedMaxModelError { get; set; } + internal int? MaxValidationDepth { get; set; } + + internal int? MaxStateDepth { get; set; } + /// <summary> /// Adds the specified <paramref name="exception"/> to the <see cref="ModelStateEntry.Errors"/> instance /// that is associated with the specified <paramref name="key"/>. If the maximum number of allowed @@ -215,7 +234,6 @@ public class ModelStateDictionary : IReadOnlyDictionary<string, ModelStateEntry? return false; } - ErrorCount++; AddModelErrorCore(key, exception); return true; } @@ -325,7 +343,6 @@ public class ModelStateDictionary : IReadOnlyDictionary<string, ModelStateEntry? return TryAddModelError(key, exception.Message); } - ErrorCount++; AddModelErrorCore(key, exception); return true; } @@ -383,13 +400,13 @@ public class ModelStateDictionary : IReadOnlyDictionary<string, ModelStateEntry? return false; } - ErrorCount++; var modelState = GetOrAddNode(key); Count += !modelState.IsContainerNode ? 0 : 1; modelState.ValidationState = ModelValidationState.Invalid; modelState.MarkNonContainerNode(); modelState.Errors.Add(errorMessage); + ErrorCount++; return true; } @@ -409,7 +426,7 @@ public class ModelStateDictionary : IReadOnlyDictionary<string, ModelStateEntry? } var item = GetNode(key); - return GetValidity(item) ?? ModelValidationState.Unvalidated; + return GetValidity(item, currentDepth: 0) ?? ModelValidationState.Unvalidated; } /// <summary> @@ -609,11 +626,18 @@ public class ModelStateDictionary : IReadOnlyDictionary<string, ModelStateEntry? var current = _root; if (key.Length > 0) { + var currentDepth = 0; var match = default(MatchResult); do { + if (MaxStateDepth != null && currentDepth >= MaxStateDepth) + { + throw new InvalidOperationException(Resources.FormatModelStateDictionary_MaxModelStateDepth(MaxStateDepth)); + } + var subKey = FindNext(key, ref match); current = current.GetOrAddNode(subKey); + currentDepth++; } while (match.Type != Delimiter.None); @@ -659,9 +683,10 @@ public class ModelStateDictionary : IReadOnlyDictionary<string, ModelStateEntry? return new StringSegment(key, keyStart, index - keyStart); } - private static ModelValidationState? GetValidity(ModelStateNode? node) + private ModelValidationState? GetValidity(ModelStateNode? node, int currentDepth) { - if (node == null) + if (node == null || + (MaxValidationDepth != null && currentDepth >= MaxValidationDepth)) { return null; } @@ -684,9 +709,11 @@ public class ModelStateDictionary : IReadOnlyDictionary<string, ModelStateEntry? if (node.ChildNodes != null) { + currentDepth++; + for (var i = 0; i < node.ChildNodes.Count; i++) { - var entryState = GetValidity(node.ChildNodes[i]); + var entryState = GetValidity(node.ChildNodes[i], currentDepth); if (entryState == ModelValidationState.Unvalidated) { @@ -710,7 +737,6 @@ public class ModelStateDictionary : IReadOnlyDictionary<string, ModelStateEntry? var exception = new TooManyModelErrorsException(Resources.ModelStateDictionary_MaxModelStateErrors); AddModelErrorCore(string.Empty, exception); HasRecordedMaxModelError = true; - ErrorCount++; } } @@ -721,6 +747,8 @@ public class ModelStateDictionary : IReadOnlyDictionary<string, ModelStateEntry? modelState.ValidationState = ModelValidationState.Invalid; modelState.MarkNonContainerNode(); modelState.Errors.Add(exception); + + ErrorCount++; } /// <summary> diff --git a/src/Mvc/Mvc.Abstractions/src/Resources.resx b/src/Mvc/Mvc.Abstractions/src/Resources.resx index c4159218e7..5d38b2bdae 100644 --- a/src/Mvc/Mvc.Abstractions/src/Resources.resx +++ b/src/Mvc/Mvc.Abstractions/src/Resources.resx @@ -180,4 +180,7 @@ <data name="RecordTypeHasValidationOnProperties" xml:space="preserve"> <value>Record type '{0}' has validation metadata defined on property '{1}' that will be ignored. '{1}' is a parameter in the record primary constructor and validation metadata must be associated with the constructor parameter.</value> </data> -</root> + <data name="ModelStateDictionary_MaxModelStateDepth" xml:space="preserve"> + <value>The specified key exceeded the maximum ModelState depth: {0}</value> + </data> +</root>
\ No newline at end of file diff --git a/src/Mvc/Mvc.Abstractions/test/ModelBinding/ModelStateDictionaryTest.cs b/src/Mvc/Mvc.Abstractions/test/ModelBinding/ModelStateDictionaryTest.cs index 6b8e919557..86d39a6fc5 100644 --- a/src/Mvc/Mvc.Abstractions/test/ModelBinding/ModelStateDictionaryTest.cs +++ b/src/Mvc/Mvc.Abstractions/test/ModelBinding/ModelStateDictionaryTest.cs @@ -1599,6 +1599,162 @@ public class ModelStateDictionaryTest Assert.Equal("value1", property.RawValue); } + [Fact] + public void GetFieldValidationState_ReturnsUnvalidated_IfTreeHeightIsGreaterThanLimit() + { + // Arrange + var stackLimit = 5; + var dictionary = new ModelStateDictionary(); + var key = string.Join(".", Enumerable.Repeat("foo", stackLimit + 1)); + dictionary.MaxValidationDepth = stackLimit; + dictionary.MaxStateDepth = null; + dictionary.MarkFieldValid(key); + + // Act + var validationState = dictionary.GetFieldValidationState("foo"); + + // Assert + Assert.Equal(ModelValidationState.Unvalidated, validationState); + } + + [Fact] + public void IsValidProperty_ReturnsTrue_IfTreeHeightIsGreaterThanLimit() + { + // Arrange + var stackLimit = 5; + var dictionary = new ModelStateDictionary(); + var key = string.Join(".", Enumerable.Repeat("foo", stackLimit + 1)); + dictionary.MaxValidationDepth = stackLimit; + dictionary.MaxStateDepth = null; + dictionary.AddModelError(key, "some error"); + + // Act + var isValid = dictionary.IsValid; + var validationState = dictionary.ValidationState; + + // Assert + Assert.True(isValid); + Assert.Equal(ModelValidationState.Valid, validationState); + } + + [Fact] + public void TryAddModelException_Throws_IfKeyHasTooManySegments() + { + // Arrange + var exception = new TestException(); + + var stateDepth = 5; + var dictionary = new ModelStateDictionary(); + var key = string.Join(".", Enumerable.Repeat("foo", stateDepth + 1)); + dictionary.MaxStateDepth = stateDepth; + + // Act + var invalidException = Assert.Throws<InvalidOperationException>(() => dictionary.TryAddModelException(key, exception)); + + // Assert + Assert.Equal( + $"The specified key exceeded the maximum ModelState depth: {dictionary.MaxStateDepth}", + invalidException.Message); + } + + [Fact] + public void TryAddModelError_Throws_IfKeyHasTooManySegments() + { + // Arrange + var stateDepth = 5; + var dictionary = new ModelStateDictionary(); + var key = string.Join(".", Enumerable.Repeat("foo", stateDepth + 1)); + dictionary.MaxStateDepth = stateDepth; + + // Act + var invalidException = Assert.Throws<InvalidOperationException>(() => dictionary.TryAddModelError(key, "errorMessage")); + + // Assert + Assert.Equal( + $"The specified key exceeded the maximum ModelState depth: {dictionary.MaxStateDepth}", + invalidException.Message); + } + + [Fact] + public void SetModelValue_Throws_IfKeyHasTooManySegments() + { + var stateDepth = 5; + var dictionary = new ModelStateDictionary(); + var key = string.Join(".", Enumerable.Repeat("foo", stateDepth + 1)); + dictionary.MaxStateDepth = stateDepth; + + // Act + var invalidException = Assert.Throws<InvalidOperationException>(() => dictionary.SetModelValue(key, string.Empty, string.Empty)); + + // Assert + Assert.Equal( + $"The specified key exceeded the maximum ModelState depth: {dictionary.MaxStateDepth}", + invalidException.Message); + } + + [Fact] + public void MarkFieldValid_Throws_IfKeyHasTooManySegments() + { + // Arrange + var stateDepth = 5; + var source = new ModelStateDictionary(); + var key = string.Join(".", Enumerable.Repeat("foo", stateDepth + 1)); + source.MaxStateDepth = stateDepth; + + // Act + var exception = Assert.Throws<InvalidOperationException>(() => source.MarkFieldValid(key)); + + // Assert + Assert.Equal( + $"The specified key exceeded the maximum ModelState depth: {source.MaxStateDepth}", + exception.Message); + } + + [Fact] + public void MarkFieldSkipped_Throws_IfKeyHasTooManySegments() + { + // Arrange + var stateDepth = 5; + var source = new ModelStateDictionary(); + var key = string.Join(".", Enumerable.Repeat("foo", stateDepth + 1)); + source.MaxStateDepth = stateDepth; + + // Act + var exception = Assert.Throws<InvalidOperationException>(() => source.MarkFieldSkipped(key)); + + // Assert + Assert.Equal( + $"The specified key exceeded the maximum ModelState depth: {source.MaxStateDepth}", + exception.Message); + } + + [Fact] + public void Constructor_SetsDefaultRecursionDepth() + { + // Arrange && Act + var dictionary = new ModelStateDictionary(); + + // Assert + Assert.Equal(ModelStateDictionary.DefaultMaxRecursionDepth, dictionary.MaxValidationDepth); + Assert.Equal(ModelStateDictionary.DefaultMaxRecursionDepth, dictionary.MaxStateDepth); + } + + [Fact] + public void CopyConstructor_PreservesRecursionDepth() + { + // Arrange + var dictionary = new ModelStateDictionary(); + dictionary.MaxValidationDepth = 5; + dictionary.MaxStateDepth = 4; + + // Act + var newDictionary = new ModelStateDictionary(dictionary); + + // Assert + Assert.Equal(dictionary.MaxValidationDepth, newDictionary.MaxValidationDepth); + Assert.Equal(dictionary.MaxStateDepth, newDictionary.MaxStateDepth); + } + private DefaultBindingMetadataProvider CreateBindingMetadataProvider() => new DefaultBindingMetadataProvider(); diff --git a/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvokerProvider.cs b/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvokerProvider.cs index 0da5397b8a..e3f03f88aa 100644 --- a/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvokerProvider.cs +++ b/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvokerProvider.cs @@ -18,6 +18,8 @@ internal sealed class ControllerActionInvokerProvider : IActionInvokerProvider private readonly ControllerActionInvokerCache _controllerActionInvokerCache; private readonly IReadOnlyList<IValueProviderFactory> _valueProviderFactories; private readonly int _maxModelValidationErrors; + private readonly int? _maxValidationDepth; + private readonly int _maxModelBindingRecursionDepth; private readonly ILogger _logger; private readonly DiagnosticListener _diagnosticListener; private readonly IActionResultTypeMapper _mapper; @@ -44,6 +46,8 @@ internal sealed class ControllerActionInvokerProvider : IActionInvokerProvider _controllerActionInvokerCache = controllerActionInvokerCache; _valueProviderFactories = optionsAccessor.Value.ValueProviderFactories.ToArray(); _maxModelValidationErrors = optionsAccessor.Value.MaxModelValidationErrors; + _maxValidationDepth = optionsAccessor.Value.MaxValidationDepth; + _maxModelBindingRecursionDepth = optionsAccessor.Value.MaxModelBindingRecursionDepth; _logger = loggerFactory.CreateLogger<ControllerActionInvoker>(); _diagnosticListener = diagnosticListener; _mapper = mapper; @@ -68,6 +72,8 @@ internal sealed class ControllerActionInvokerProvider : IActionInvokerProvider ValueProviderFactories = new CopyOnWriteList<IValueProviderFactory>(_valueProviderFactories) }; controllerContext.ModelState.MaxAllowedErrors = _maxModelValidationErrors; + controllerContext.ModelState.MaxValidationDepth = _maxValidationDepth; + controllerContext.ModelState.MaxStateDepth = _maxModelBindingRecursionDepth; var (cacheEntry, filters) = _controllerActionInvokerCache.GetCachedResult(controllerContext); diff --git a/src/Mvc/Mvc.Core/src/Routing/ControllerRequestDelegateFactory.cs b/src/Mvc/Mvc.Core/src/Routing/ControllerRequestDelegateFactory.cs index 2e52948978..f31448e787 100644 --- a/src/Mvc/Mvc.Core/src/Routing/ControllerRequestDelegateFactory.cs +++ b/src/Mvc/Mvc.Core/src/Routing/ControllerRequestDelegateFactory.cs @@ -19,6 +19,8 @@ internal sealed class ControllerRequestDelegateFactory : IRequestDelegateFactory private readonly ControllerActionInvokerCache _controllerActionInvokerCache; private readonly IReadOnlyList<IValueProviderFactory> _valueProviderFactories; private readonly int _maxModelValidationErrors; + private readonly int? _maxValidationDepth; + private readonly int _maxModelBindingRecursionDepth; private readonly ILogger _logger; private readonly DiagnosticListener _diagnosticListener; private readonly IActionResultTypeMapper _mapper; @@ -46,6 +48,8 @@ internal sealed class ControllerRequestDelegateFactory : IRequestDelegateFactory _controllerActionInvokerCache = controllerActionInvokerCache; _valueProviderFactories = optionsAccessor.Value.ValueProviderFactories.ToArray(); _maxModelValidationErrors = optionsAccessor.Value.MaxModelValidationErrors; + _maxValidationDepth = optionsAccessor.Value.MaxValidationDepth; + _maxModelBindingRecursionDepth = optionsAccessor.Value.MaxModelBindingRecursionDepth; _enableActionInvokers = optionsAccessor.Value.EnableActionInvokers; _logger = loggerFactory.CreateLogger<ControllerActionInvoker>(); _diagnosticListener = diagnosticListener; @@ -82,6 +86,8 @@ internal sealed class ControllerRequestDelegateFactory : IRequestDelegateFactory }; controllerContext.ModelState.MaxAllowedErrors = _maxModelValidationErrors; + controllerContext.ModelState.MaxValidationDepth = _maxValidationDepth; + controllerContext.ModelState.MaxStateDepth = _maxModelBindingRecursionDepth; var (cacheEntry, filters) = _controllerActionInvokerCache.GetCachedResult(controllerContext); diff --git a/src/Mvc/Mvc.Core/test/Infrastructure/ControllerActionInvokerProviderTest.cs b/src/Mvc/Mvc.Core/test/Infrastructure/ControllerActionInvokerProviderTest.cs new file mode 100644 index 0000000000..374f263899 --- /dev/null +++ b/src/Mvc/Mvc.Core/test/Infrastructure/ControllerActionInvokerProviderTest.cs @@ -0,0 +1,96 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Reflection; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc.Abstractions; +using Microsoft.AspNetCore.Mvc.Controllers; +using Microsoft.AspNetCore.Mvc.Filters; +using Microsoft.AspNetCore.Mvc.ModelBinding; +using Microsoft.AspNetCore.Mvc.ModelBinding.Validation; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; +using Moq; + +namespace Microsoft.AspNetCore.Mvc.Infrastructure; + +public class ControllerActionInvokerProviderTest +{ + [Fact] + public void OnExecuting_ConfiguresModelState_WithMvcOptions() + { + // Arrange + var provider = CreateInvokerProvider(new MvcOptions() { MaxValidationDepth = 1, MaxModelBindingRecursionDepth = 2, MaxModelValidationErrors = 3 }); + + var context = new ActionInvokerProviderContext(new ActionContext() + { + ActionDescriptor = GetControllerActionDescriptor(), + HttpContext = new DefaultHttpContext(), + RouteData = new RouteData(), + }); + + // Act + provider.OnProvidersExecuting(context); + + // Assert + var invoker = Assert.IsType<ControllerActionInvoker>(context.Result); + Assert.Equal(1, invoker.ControllerContext.ModelState.MaxValidationDepth); + Assert.Equal(2, invoker.ControllerContext.ModelState.MaxStateDepth); + Assert.Equal(3, invoker.ControllerContext.ModelState.MaxAllowedErrors); + + } + + private static ControllerActionDescriptor GetControllerActionDescriptor() + { + var method = typeof(TestActions).GetMethod(nameof(TestActions.GetAction)); + var actionDescriptor = new ControllerActionDescriptor + { + MethodInfo = method, + FilterDescriptors = new List<FilterDescriptor>(), + ControllerTypeInfo = typeof(TestActions).GetTypeInfo(), + }; + + foreach (var filterAttribute in method.GetCustomAttributes().OfType<IFilterMetadata>()) + { + actionDescriptor.FilterDescriptors.Add(new FilterDescriptor(filterAttribute, FilterScope.Action)); + } + + return actionDescriptor; + } + + private static ControllerActionInvokerProvider CreateInvokerProvider(MvcOptions mvcOptions = null) + { + var modelMetadataProvider = TestModelMetadataProvider.CreateDefaultProvider(); + var modelBinderFactory = TestModelBinderFactory.CreateDefault(); + mvcOptions ??= new MvcOptions(); + + var parameterBinder = new ParameterBinder( + modelMetadataProvider, + TestModelBinderFactory.CreateDefault(), + Mock.Of<IObjectModelValidator>(), + Options.Create(mvcOptions), + NullLoggerFactory.Instance); + + var cache = new ControllerActionInvokerCache( + parameterBinder, + modelBinderFactory, + modelMetadataProvider, + new[] { new DefaultFilterProvider() }, + Mock.Of<IControllerFactoryProvider>(), + Options.Create(mvcOptions)); + + return new( + cache, + Options.Create(mvcOptions), + NullLoggerFactory.Instance, + new DiagnosticListener("Microsoft.AspNetCore"), + new ActionResultTypeMapper()); + } + + private class TestActions : Controller + { + public IActionResult GetAction() => new OkResult(); + } +} |