diff options
author | Spencer Fricke <spencerfricke@gmail.com> | 2022-10-24 19:45:08 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-24 19:45:08 +0300 |
commit | 0ebf830572133cc0b95e39990ae0bb0767aa52fe (patch) | |
tree | 9cc3b7eece30f62d494910be0a55ab2c847c648f | |
parent | eb113f0fdfff8efc114953bdabf1738db681ad8d (diff) |
spirv-val: Add OpPtrAccessChain Base checks (#4965)
-rw-r--r-- | source/val/validate_memory.cpp | 46 | ||||
-rw-r--r-- | test/opt/eliminate_dead_member_test.cpp | 1 | ||||
-rw-r--r-- | test/val/val_id_test.cpp | 25 | ||||
-rw-r--r-- | test/val/val_memory_test.cpp | 214 |
4 files changed, 279 insertions, 7 deletions
diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp index 074bdb88a..8a66beef7 100644 --- a/source/val/validate_memory.cpp +++ b/source/val/validate_memory.cpp @@ -1410,7 +1410,51 @@ spv_result_t ValidatePtrAccessChain(ValidationState_t& _, << "VariablePointers or VariablePointersStorageBuffer"; } } - return ValidateAccessChain(_, inst); + + // Need to call first, will make sure Base is a valid ID + if (auto error = ValidateAccessChain(_, inst)) return error; + + const auto base_id = inst->GetOperandAs<uint32_t>(2); + const auto base = _.FindDef(base_id); + const auto base_type = _.FindDef(base->type_id()); + const auto base_type_storage_class = base_type->word(2); + + if (_.HasCapability(SpvCapabilityShader) && + (base_type_storage_class == SpvStorageClassUniform || + base_type_storage_class == SpvStorageClassStorageBuffer || + base_type_storage_class == SpvStorageClassPhysicalStorageBuffer || + base_type_storage_class == SpvStorageClassPushConstant || + (_.HasCapability(SpvCapabilityWorkgroupMemoryExplicitLayoutKHR) && + base_type_storage_class == SpvStorageClassWorkgroup)) && + !_.HasDecoration(base_type->id(), SpvDecorationArrayStride)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "OpPtrAccessChain must have a Base whose type is decorated " + "with ArrayStride"; + } + + if (spvIsVulkanEnv(_.context()->target_env)) { + if (base_type_storage_class == SpvStorageClassWorkgroup) { + if (!_.HasCapability(SpvCapabilityVariablePointers)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "OpPtrAccessChain Base operand pointing to Workgroup " + "storage class must use VariablePointers capability"; + } + } else if (base_type_storage_class == SpvStorageClassStorageBuffer) { + if (!_.features().variable_pointers) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "OpPtrAccessChain Base operand pointing to StorageBuffer " + "storage class must use VariablePointers or " + "VariablePointersStorageBuffer capability"; + } + } else if (base_type_storage_class != + SpvStorageClassPhysicalStorageBuffer) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "OpPtrAccessChain Base operand must point to Workgroup, " + "StorageBuffer, or PhysicalStorageBuffer storage class"; + } + } + + return SPV_SUCCESS; } spv_result_t ValidateArrayLength(ValidationState_t& state, diff --git a/test/opt/eliminate_dead_member_test.cpp b/test/opt/eliminate_dead_member_test.cpp index e277999e4..4438f3d86 100644 --- a/test/opt/eliminate_dead_member_test.cpp +++ b/test/opt/eliminate_dead_member_test.cpp @@ -978,6 +978,7 @@ TEST_F(EliminateDeadMemberTest, RemoveMemberPtrAccessChain) { OpMemberDecorate %type__Globals 1 Offset 4 OpMemberDecorate %type__Globals 2 Offset 16 OpDecorate %type__Globals Block + OpDecorate %_ptr_Uniform_type__Globals ArrayStride 8 %uint = OpTypeInt 32 0 %uint_0 = OpConstant %uint 0 %uint_1 = OpConstant %uint 1 diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp index dd040b33e..a457980dc 100644 --- a/test/val/val_id_test.cpp +++ b/test/val/val_id_test.cpp @@ -3981,8 +3981,11 @@ OpFunctionEnd TEST_P(AccessChainInstructionTest, AccessChainTooManyIndexesGood) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? " %int_0 " : ""; + const std::string arrayStride = + " OpDecorate %_ptr_Uniform_deep_struct ArrayStride 8 "; int depth = 255; - std::string header = kGLSL450MemoryModel + kDeeplyNestedStructureSetup; + std::string header = + kGLSL450MemoryModel + arrayStride + kDeeplyNestedStructureSetup; header.erase(header.find("%func")); std::ostringstream spirv; spirv << header << "\n"; @@ -4044,8 +4047,11 @@ TEST_P(AccessChainInstructionTest, AccessChainTooManyIndexesBad) { TEST_P(AccessChainInstructionTest, CustomizedAccessChainTooManyIndexesGood) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? " %int_0 " : ""; + const std::string arrayStride = + " OpDecorate %_ptr_Uniform_deep_struct ArrayStride 8 "; int depth = 10; - std::string header = kGLSL450MemoryModel + kDeeplyNestedStructureSetup; + std::string header = + kGLSL450MemoryModel + arrayStride + kDeeplyNestedStructureSetup; header.erase(header.find("%func")); std::ostringstream spirv; spirv << header << "\n"; @@ -4217,8 +4223,11 @@ TEST_P(AccessChainInstructionTest, AccessChainIndexIntoAllTypesGood) { // 0 will select the element at the index 0 of the vector. (which is a float). const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; + const std::string arrayStride = + " OpDecorate %_ptr_Uniform_blockName ArrayStride 8 "; std::ostringstream spirv; - spirv << kGLSL450MemoryModel << kDeeplyNestedStructureSetup << std::endl; + spirv << kGLSL450MemoryModel << arrayStride << kDeeplyNestedStructureSetup + << std::endl; spirv << "%ss = " << instr << " %_ptr_Uniform_struct_s %blockName_var " << elem << "%int_0" << std::endl; spirv << "%sa = " << instr << " %_ptr_Uniform_array5_mat4x3 %blockName_var " @@ -4241,9 +4250,12 @@ OpFunctionEnd TEST_P(AccessChainInstructionTest, AccessChainIndexIntoRuntimeArrayGood) { const std::string instr = GetParam(); const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : ""; - std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"( -%runtime_arr_entry = )" + - instr + R"( %_ptr_Uniform_float %blockName_var )" + elem + + const std::string arrayStride = + " OpDecorate %_ptr_Uniform_blockName ArrayStride 8 "; + std::string spirv = kGLSL450MemoryModel + arrayStride + + kDeeplyNestedStructureSetup + R"( +%runtime_arr_entry = )" + instr + + R"( %_ptr_Uniform_float %blockName_var )" + elem + R"(%int_2 %int_0 OpReturn OpFunctionEnd @@ -5631,6 +5643,7 @@ TEST_P(ValidateIdWithMessage, StgBufOpPtrAccessChainGood) { OpExtension "SPV_KHR_variable_pointers" OpMemoryModel Logical GLSL450 OpEntryPoint GLCompute %3 "" + OpDecorate %ptr ArrayStride 8 %int = OpTypeInt 32 0 %int_2 = OpConstant %int 2 %int_4 = OpConstant %int 4 diff --git a/test/val/val_memory_test.cpp b/test/val/val_memory_test.cpp index 4299eda9d..780aeedc8 100644 --- a/test/val/val_memory_test.cpp +++ b/test/val/val_memory_test.cpp @@ -4712,6 +4712,220 @@ OpFunctionEnd EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } +TEST_F(ValidateMemory, PtrAccessChainArrayStrideBad) { + const std::string spirv = R"( + OpCapability Shader + OpCapability VariablePointersStorageBuffer + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_KHR_variable_pointers" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "foo" %var + OpExecutionMode %main LocalSize 1 1 1 + OpDecorate %var DescriptorSet 0 + OpDecorate %var Binding 0 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %ptr = OpTypePointer StorageBuffer %uint + %void = OpTypeVoid + %func = OpTypeFunction %void + %var = OpVariable %ptr StorageBuffer + %main = OpFunction %void None %func + %label = OpLabel + %access = OpAccessChain %ptr %var + %ptr_access = OpPtrAccessChain %ptr %access %uint_1 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_5); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, + ValidateInstructions(SPV_ENV_UNIVERSAL_1_5)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpPtrAccessChain must have a Base whose type is " + "decorated with ArrayStride")); +} + +TEST_F(ValidateMemory, PtrAccessChainArrayStrideSuccess) { + const std::string spirv = R"( + OpCapability Shader + OpCapability VariablePointersStorageBuffer + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_KHR_variable_pointers" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "foo" %var + OpExecutionMode %main LocalSize 1 1 1 + OpDecorate %var DescriptorSet 0 + OpDecorate %var Binding 00 + OpDecorate %ptr ArrayStride 4 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %ptr = OpTypePointer StorageBuffer %uint + %void = OpTypeVoid + %func = OpTypeFunction %void + %var = OpVariable %ptr StorageBuffer + %main = OpFunction %void None %func + %label = OpLabel + %access = OpAccessChain %ptr %var + %ptr_access = OpPtrAccessChain %ptr %access %uint_1 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_5); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_5)); +} + +TEST_F(ValidateMemory, VulkanPtrAccessChainStorageBufferSuccess) { + const std::string spirv = R"( + OpCapability Shader + OpCapability VariablePointersStorageBuffer + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_KHR_variable_pointers" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "foo" %var + OpExecutionMode %main LocalSize 1 1 1 + OpDecorate %_runtimearr_uint ArrayStride 4 + OpMemberDecorate %_struct_10 0 Offset 0 + OpDecorate %_struct_10 Block + OpDecorate %var DescriptorSet 0 + OpDecorate %var Binding 0 + OpDecorate %_ptr_StorageBuffer_uint ArrayStride 4 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 +%_runtimearr_uint = OpTypeRuntimeArray %uint + %_struct_10 = OpTypeStruct %_runtimearr_uint +%_ptr_StorageBuffer__struct_10 = OpTypePointer StorageBuffer %_struct_10 +%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint + %void = OpTypeVoid + %func2 = OpTypeFunction %void %_ptr_StorageBuffer_uint + %func1 = OpTypeFunction %void + %var = OpVariable %_ptr_StorageBuffer__struct_10 StorageBuffer + %called = OpFunction %void None %func2 + %param = OpFunctionParameter %_ptr_StorageBuffer_uint + %label2 = OpLabel + %ptr_access = OpPtrAccessChain %_ptr_StorageBuffer_uint %param %uint_1 + OpReturn + OpFunctionEnd + %main = OpFunction %void None %func1 + %label1 = OpLabel + %access = OpAccessChain %_ptr_StorageBuffer_uint %var %uint_0 %uint_0 + %call = OpFunctionCall %void %called %access + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_2); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2)); +} + +TEST_F(ValidateMemory, VulkanPtrAccessChainStorageBufferCapability) { + const std::string spirv = R"( + OpCapability Shader + OpCapability PhysicalStorageBufferAddresses + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_KHR_variable_pointers" + OpMemoryModel PhysicalStorageBuffer64 GLSL450 + OpEntryPoint GLCompute %main "foo" %var + OpExecutionMode %main LocalSize 1 1 1 + OpDecorate %_runtimearr_uint ArrayStride 4 + OpMemberDecorate %_struct_10 0 Offset 0 + OpDecorate %_struct_10 Block + OpDecorate %var DescriptorSet 0 + OpDecorate %var Binding 0 + OpDecorate %_ptr_StorageBuffer_uint ArrayStride 4 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 +%_runtimearr_uint = OpTypeRuntimeArray %uint + %_struct_10 = OpTypeStruct %_runtimearr_uint +%_ptr_StorageBuffer__struct_10 = OpTypePointer StorageBuffer %_struct_10 +%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint + %void = OpTypeVoid + %func = OpTypeFunction %void + %var = OpVariable %_ptr_StorageBuffer__struct_10 StorageBuffer + %main = OpFunction %void None %func + %label = OpLabel + %access = OpAccessChain %_ptr_StorageBuffer_uint %var %uint_0 %uint_0 + %ptr_access = OpPtrAccessChain %_ptr_StorageBuffer_uint %access %uint_1 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_2); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_2)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpPtrAccessChain Base operand pointing to " + "StorageBuffer storage class must use VariablePointers " + "or VariablePointersStorageBuffer capability")); +} + +TEST_F(ValidateMemory, VulkanPtrAccessChainWorkgroupCapability) { + const std::string spirv = R"( + OpCapability Shader + OpCapability VariablePointersStorageBuffer + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_KHR_variable_pointers" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "foo" %var + OpExecutionMode %main LocalSize 1 1 1 + OpDecorate %_ptr_Workgroup_uint ArrayStride 4 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 +%_arr_uint = OpTypeArray %uint %uint_1 +%_ptr_Workgroup__arr_uint = OpTypePointer Workgroup %_arr_uint +%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint + %void = OpTypeVoid + %func = OpTypeFunction %void + %var = OpVariable %_ptr_Workgroup__arr_uint Workgroup + %main = OpFunction %void None %func + %label = OpLabel + %access = OpAccessChain %_ptr_Workgroup_uint %var %uint_0 + %ptr_access = OpPtrAccessChain %_ptr_Workgroup_uint %access %uint_1 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_2); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_2)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("OpPtrAccessChain Base operand pointing to Workgroup " + "storage class must use VariablePointers capability")); +} + +TEST_F(ValidateMemory, VulkanPtrAccessChainWorkgroupNoArrayStrideSuccess) { + const std::string spirv = R"( + OpCapability Shader + OpCapability VariablePointers + OpExtension "SPV_KHR_storage_buffer_storage_class" + OpExtension "SPV_KHR_variable_pointers" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "foo" %var + OpExecutionMode %main LocalSize 1 1 1 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 +%_arr_uint = OpTypeArray %uint %uint_1 +%_ptr_Workgroup__arr_uint = OpTypePointer Workgroup %_arr_uint +%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint + %void = OpTypeVoid + %func = OpTypeFunction %void + %var = OpVariable %_ptr_Workgroup__arr_uint Workgroup + %main = OpFunction %void None %func + %label = OpLabel + %access = OpAccessChain %_ptr_Workgroup_uint %var %uint_0 + %ptr_access = OpPtrAccessChain %_ptr_Workgroup_uint %access %uint_1 + OpReturn + OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_2); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2)); +} + } // namespace } // namespace val } // namespace spvtools |