Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/KhronosGroup/SPIRV-Tools.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSpencer Fricke <spencerfricke@gmail.com>2022-10-24 19:45:08 +0300
committerGitHub <noreply@github.com>2022-10-24 19:45:08 +0300
commit0ebf830572133cc0b95e39990ae0bb0767aa52fe (patch)
tree9cc3b7eece30f62d494910be0a55ab2c847c648f
parenteb113f0fdfff8efc114953bdabf1738db681ad8d (diff)
spirv-val: Add OpPtrAccessChain Base checks (#4965)
-rw-r--r--source/val/validate_memory.cpp46
-rw-r--r--test/opt/eliminate_dead_member_test.cpp1
-rw-r--r--test/val/val_id_test.cpp25
-rw-r--r--test/val/val_memory_test.cpp214
4 files changed, 279 insertions, 7 deletions
diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp
index 074bdb88a..8a66beef7 100644
--- a/source/val/validate_memory.cpp
+++ b/source/val/validate_memory.cpp
@@ -1410,7 +1410,51 @@ spv_result_t ValidatePtrAccessChain(ValidationState_t& _,
<< "VariablePointers or VariablePointersStorageBuffer";
}
}
- return ValidateAccessChain(_, inst);
+
+ // Need to call first, will make sure Base is a valid ID
+ if (auto error = ValidateAccessChain(_, inst)) return error;
+
+ const auto base_id = inst->GetOperandAs<uint32_t>(2);
+ const auto base = _.FindDef(base_id);
+ const auto base_type = _.FindDef(base->type_id());
+ const auto base_type_storage_class = base_type->word(2);
+
+ if (_.HasCapability(SpvCapabilityShader) &&
+ (base_type_storage_class == SpvStorageClassUniform ||
+ base_type_storage_class == SpvStorageClassStorageBuffer ||
+ base_type_storage_class == SpvStorageClassPhysicalStorageBuffer ||
+ base_type_storage_class == SpvStorageClassPushConstant ||
+ (_.HasCapability(SpvCapabilityWorkgroupMemoryExplicitLayoutKHR) &&
+ base_type_storage_class == SpvStorageClassWorkgroup)) &&
+ !_.HasDecoration(base_type->id(), SpvDecorationArrayStride)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "OpPtrAccessChain must have a Base whose type is decorated "
+ "with ArrayStride";
+ }
+
+ if (spvIsVulkanEnv(_.context()->target_env)) {
+ if (base_type_storage_class == SpvStorageClassWorkgroup) {
+ if (!_.HasCapability(SpvCapabilityVariablePointers)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "OpPtrAccessChain Base operand pointing to Workgroup "
+ "storage class must use VariablePointers capability";
+ }
+ } else if (base_type_storage_class == SpvStorageClassStorageBuffer) {
+ if (!_.features().variable_pointers) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "OpPtrAccessChain Base operand pointing to StorageBuffer "
+ "storage class must use VariablePointers or "
+ "VariablePointersStorageBuffer capability";
+ }
+ } else if (base_type_storage_class !=
+ SpvStorageClassPhysicalStorageBuffer) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "OpPtrAccessChain Base operand must point to Workgroup, "
+ "StorageBuffer, or PhysicalStorageBuffer storage class";
+ }
+ }
+
+ return SPV_SUCCESS;
}
spv_result_t ValidateArrayLength(ValidationState_t& state,
diff --git a/test/opt/eliminate_dead_member_test.cpp b/test/opt/eliminate_dead_member_test.cpp
index e277999e4..4438f3d86 100644
--- a/test/opt/eliminate_dead_member_test.cpp
+++ b/test/opt/eliminate_dead_member_test.cpp
@@ -978,6 +978,7 @@ TEST_F(EliminateDeadMemberTest, RemoveMemberPtrAccessChain) {
OpMemberDecorate %type__Globals 1 Offset 4
OpMemberDecorate %type__Globals 2 Offset 16
OpDecorate %type__Globals Block
+ OpDecorate %_ptr_Uniform_type__Globals ArrayStride 8
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%uint_1 = OpConstant %uint 1
diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp
index dd040b33e..a457980dc 100644
--- a/test/val/val_id_test.cpp
+++ b/test/val/val_id_test.cpp
@@ -3981,8 +3981,11 @@ OpFunctionEnd
TEST_P(AccessChainInstructionTest, AccessChainTooManyIndexesGood) {
const std::string instr = GetParam();
const std::string elem = AccessChainRequiresElemId(instr) ? " %int_0 " : "";
+ const std::string arrayStride =
+ " OpDecorate %_ptr_Uniform_deep_struct ArrayStride 8 ";
int depth = 255;
- std::string header = kGLSL450MemoryModel + kDeeplyNestedStructureSetup;
+ std::string header =
+ kGLSL450MemoryModel + arrayStride + kDeeplyNestedStructureSetup;
header.erase(header.find("%func"));
std::ostringstream spirv;
spirv << header << "\n";
@@ -4044,8 +4047,11 @@ TEST_P(AccessChainInstructionTest, AccessChainTooManyIndexesBad) {
TEST_P(AccessChainInstructionTest, CustomizedAccessChainTooManyIndexesGood) {
const std::string instr = GetParam();
const std::string elem = AccessChainRequiresElemId(instr) ? " %int_0 " : "";
+ const std::string arrayStride =
+ " OpDecorate %_ptr_Uniform_deep_struct ArrayStride 8 ";
int depth = 10;
- std::string header = kGLSL450MemoryModel + kDeeplyNestedStructureSetup;
+ std::string header =
+ kGLSL450MemoryModel + arrayStride + kDeeplyNestedStructureSetup;
header.erase(header.find("%func"));
std::ostringstream spirv;
spirv << header << "\n";
@@ -4217,8 +4223,11 @@ TEST_P(AccessChainInstructionTest, AccessChainIndexIntoAllTypesGood) {
// 0 will select the element at the index 0 of the vector. (which is a float).
const std::string instr = GetParam();
const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
+ const std::string arrayStride =
+ " OpDecorate %_ptr_Uniform_blockName ArrayStride 8 ";
std::ostringstream spirv;
- spirv << kGLSL450MemoryModel << kDeeplyNestedStructureSetup << std::endl;
+ spirv << kGLSL450MemoryModel << arrayStride << kDeeplyNestedStructureSetup
+ << std::endl;
spirv << "%ss = " << instr << " %_ptr_Uniform_struct_s %blockName_var "
<< elem << "%int_0" << std::endl;
spirv << "%sa = " << instr << " %_ptr_Uniform_array5_mat4x3 %blockName_var "
@@ -4241,9 +4250,12 @@ OpFunctionEnd
TEST_P(AccessChainInstructionTest, AccessChainIndexIntoRuntimeArrayGood) {
const std::string instr = GetParam();
const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
- std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"(
-%runtime_arr_entry = )" +
- instr + R"( %_ptr_Uniform_float %blockName_var )" + elem +
+ const std::string arrayStride =
+ " OpDecorate %_ptr_Uniform_blockName ArrayStride 8 ";
+ std::string spirv = kGLSL450MemoryModel + arrayStride +
+ kDeeplyNestedStructureSetup + R"(
+%runtime_arr_entry = )" + instr +
+ R"( %_ptr_Uniform_float %blockName_var )" + elem +
R"(%int_2 %int_0
OpReturn
OpFunctionEnd
@@ -5631,6 +5643,7 @@ TEST_P(ValidateIdWithMessage, StgBufOpPtrAccessChainGood) {
OpExtension "SPV_KHR_variable_pointers"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %3 ""
+ OpDecorate %ptr ArrayStride 8
%int = OpTypeInt 32 0
%int_2 = OpConstant %int 2
%int_4 = OpConstant %int 4
diff --git a/test/val/val_memory_test.cpp b/test/val/val_memory_test.cpp
index 4299eda9d..780aeedc8 100644
--- a/test/val/val_memory_test.cpp
+++ b/test/val/val_memory_test.cpp
@@ -4712,6 +4712,220 @@ OpFunctionEnd
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
+TEST_F(ValidateMemory, PtrAccessChainArrayStrideBad) {
+ const std::string spirv = R"(
+ OpCapability Shader
+ OpCapability VariablePointersStorageBuffer
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpExtension "SPV_KHR_variable_pointers"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "foo" %var
+ OpExecutionMode %main LocalSize 1 1 1
+ OpDecorate %var DescriptorSet 0
+ OpDecorate %var Binding 0
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %uint_1 = OpConstant %uint 1
+ %ptr = OpTypePointer StorageBuffer %uint
+ %void = OpTypeVoid
+ %func = OpTypeFunction %void
+ %var = OpVariable %ptr StorageBuffer
+ %main = OpFunction %void None %func
+ %label = OpLabel
+ %access = OpAccessChain %ptr %var
+ %ptr_access = OpPtrAccessChain %ptr %access %uint_1
+ OpReturn
+ OpFunctionEnd
+)";
+
+ CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_5);
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA,
+ ValidateInstructions(SPV_ENV_UNIVERSAL_1_5));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("OpPtrAccessChain must have a Base whose type is "
+ "decorated with ArrayStride"));
+}
+
+TEST_F(ValidateMemory, PtrAccessChainArrayStrideSuccess) {
+ const std::string spirv = R"(
+ OpCapability Shader
+ OpCapability VariablePointersStorageBuffer
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpExtension "SPV_KHR_variable_pointers"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "foo" %var
+ OpExecutionMode %main LocalSize 1 1 1
+ OpDecorate %var DescriptorSet 0
+ OpDecorate %var Binding 00
+ OpDecorate %ptr ArrayStride 4
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %uint_1 = OpConstant %uint 1
+ %ptr = OpTypePointer StorageBuffer %uint
+ %void = OpTypeVoid
+ %func = OpTypeFunction %void
+ %var = OpVariable %ptr StorageBuffer
+ %main = OpFunction %void None %func
+ %label = OpLabel
+ %access = OpAccessChain %ptr %var
+ %ptr_access = OpPtrAccessChain %ptr %access %uint_1
+ OpReturn
+ OpFunctionEnd
+)";
+
+ CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_5);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_5));
+}
+
+TEST_F(ValidateMemory, VulkanPtrAccessChainStorageBufferSuccess) {
+ const std::string spirv = R"(
+ OpCapability Shader
+ OpCapability VariablePointersStorageBuffer
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpExtension "SPV_KHR_variable_pointers"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "foo" %var
+ OpExecutionMode %main LocalSize 1 1 1
+ OpDecorate %_runtimearr_uint ArrayStride 4
+ OpMemberDecorate %_struct_10 0 Offset 0
+ OpDecorate %_struct_10 Block
+ OpDecorate %var DescriptorSet 0
+ OpDecorate %var Binding 0
+ OpDecorate %_ptr_StorageBuffer_uint ArrayStride 4
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %uint_1 = OpConstant %uint 1
+%_runtimearr_uint = OpTypeRuntimeArray %uint
+ %_struct_10 = OpTypeStruct %_runtimearr_uint
+%_ptr_StorageBuffer__struct_10 = OpTypePointer StorageBuffer %_struct_10
+%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
+ %void = OpTypeVoid
+ %func2 = OpTypeFunction %void %_ptr_StorageBuffer_uint
+ %func1 = OpTypeFunction %void
+ %var = OpVariable %_ptr_StorageBuffer__struct_10 StorageBuffer
+ %called = OpFunction %void None %func2
+ %param = OpFunctionParameter %_ptr_StorageBuffer_uint
+ %label2 = OpLabel
+ %ptr_access = OpPtrAccessChain %_ptr_StorageBuffer_uint %param %uint_1
+ OpReturn
+ OpFunctionEnd
+ %main = OpFunction %void None %func1
+ %label1 = OpLabel
+ %access = OpAccessChain %_ptr_StorageBuffer_uint %var %uint_0 %uint_0
+ %call = OpFunctionCall %void %called %access
+ OpReturn
+ OpFunctionEnd
+)";
+
+ CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateMemory, VulkanPtrAccessChainStorageBufferCapability) {
+ const std::string spirv = R"(
+ OpCapability Shader
+ OpCapability PhysicalStorageBufferAddresses
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpExtension "SPV_KHR_variable_pointers"
+ OpMemoryModel PhysicalStorageBuffer64 GLSL450
+ OpEntryPoint GLCompute %main "foo" %var
+ OpExecutionMode %main LocalSize 1 1 1
+ OpDecorate %_runtimearr_uint ArrayStride 4
+ OpMemberDecorate %_struct_10 0 Offset 0
+ OpDecorate %_struct_10 Block
+ OpDecorate %var DescriptorSet 0
+ OpDecorate %var Binding 0
+ OpDecorate %_ptr_StorageBuffer_uint ArrayStride 4
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %uint_1 = OpConstant %uint 1
+%_runtimearr_uint = OpTypeRuntimeArray %uint
+ %_struct_10 = OpTypeStruct %_runtimearr_uint
+%_ptr_StorageBuffer__struct_10 = OpTypePointer StorageBuffer %_struct_10
+%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
+ %void = OpTypeVoid
+ %func = OpTypeFunction %void
+ %var = OpVariable %_ptr_StorageBuffer__struct_10 StorageBuffer
+ %main = OpFunction %void None %func
+ %label = OpLabel
+ %access = OpAccessChain %_ptr_StorageBuffer_uint %var %uint_0 %uint_0
+ %ptr_access = OpPtrAccessChain %_ptr_StorageBuffer_uint %access %uint_1
+ OpReturn
+ OpFunctionEnd
+)";
+
+ CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("OpPtrAccessChain Base operand pointing to "
+ "StorageBuffer storage class must use VariablePointers "
+ "or VariablePointersStorageBuffer capability"));
+}
+
+TEST_F(ValidateMemory, VulkanPtrAccessChainWorkgroupCapability) {
+ const std::string spirv = R"(
+ OpCapability Shader
+ OpCapability VariablePointersStorageBuffer
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpExtension "SPV_KHR_variable_pointers"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "foo" %var
+ OpExecutionMode %main LocalSize 1 1 1
+ OpDecorate %_ptr_Workgroup_uint ArrayStride 4
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %uint_1 = OpConstant %uint 1
+%_arr_uint = OpTypeArray %uint %uint_1
+%_ptr_Workgroup__arr_uint = OpTypePointer Workgroup %_arr_uint
+%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+ %void = OpTypeVoid
+ %func = OpTypeFunction %void
+ %var = OpVariable %_ptr_Workgroup__arr_uint Workgroup
+ %main = OpFunction %void None %func
+ %label = OpLabel
+ %access = OpAccessChain %_ptr_Workgroup_uint %var %uint_0
+ %ptr_access = OpPtrAccessChain %_ptr_Workgroup_uint %access %uint_1
+ OpReturn
+ OpFunctionEnd
+)";
+
+ CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("OpPtrAccessChain Base operand pointing to Workgroup "
+ "storage class must use VariablePointers capability"));
+}
+
+TEST_F(ValidateMemory, VulkanPtrAccessChainWorkgroupNoArrayStrideSuccess) {
+ const std::string spirv = R"(
+ OpCapability Shader
+ OpCapability VariablePointers
+ OpExtension "SPV_KHR_storage_buffer_storage_class"
+ OpExtension "SPV_KHR_variable_pointers"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "foo" %var
+ OpExecutionMode %main LocalSize 1 1 1
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %uint_1 = OpConstant %uint 1
+%_arr_uint = OpTypeArray %uint %uint_1
+%_ptr_Workgroup__arr_uint = OpTypePointer Workgroup %_arr_uint
+%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+ %void = OpTypeVoid
+ %func = OpTypeFunction %void
+ %var = OpVariable %_ptr_Workgroup__arr_uint Workgroup
+ %main = OpFunction %void None %func
+ %label = OpLabel
+ %access = OpAccessChain %_ptr_Workgroup_uint %var %uint_0
+ %ptr_access = OpPtrAccessChain %_ptr_Workgroup_uint %access %uint_1
+ OpReturn
+ OpFunctionEnd
+)";
+
+ CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
} // namespace
} // namespace val
} // namespace spvtools