Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/KhronosGroup/SPIRV-Tools.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSpencer Fricke <spencerfricke@gmail.com>2022-09-23 15:45:11 +0300
committerGitHub <noreply@github.com>2022-09-23 15:45:11 +0300
commitddbee48f85e3cb977695835de364e6f05e82dd62 (patch)
tree1fdd8acd372a9c1f1dd0976607b58c267f49918b
parentf98473ceeb1d33700d01e20910433583e5256030 (diff)
spirv-opt: Fix stacked CompositeExtract constant folds (#4932)
This was spotted in the Validation Layers where OpSpecConstantOp %x CompositeExtract %y 0 was being folded to a constant, but anything that was using it wasn't recognizing it as a constant, the simple fix was to add a const_mgr->MapInst(new_const_inst); so the next instruction knew it was a const
-rw-r--r--source/opt/fold.cpp3
-rw-r--r--source/opt/fold_spec_constant_op_and_composite_pass.cpp34
-rw-r--r--test/opt/fold_spec_const_op_composite_test.cpp205
3 files changed, 221 insertions, 21 deletions
diff --git a/source/opt/fold.cpp b/source/opt/fold.cpp
index b903da6a2..315741ad7 100644
--- a/source/opt/fold.cpp
+++ b/source/opt/fold.cpp
@@ -627,8 +627,7 @@ Instruction* InstructionFolder::FoldInstructionToConstant(
Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const {
analysis::ConstantManager* const_mgr = context_->get_constant_mgr();
- if (!inst->IsFoldableByFoldScalar() &&
- !GetConstantFoldingRules().HasFoldingRule(inst)) {
+ if (!inst->IsFoldableByFoldScalar() && !HasConstFoldingRule(inst)) {
return nullptr;
}
// Collect the values of the constant parameters.
diff --git a/source/opt/fold_spec_constant_op_and_composite_pass.cpp b/source/opt/fold_spec_constant_op_and_composite_pass.cpp
index 8d68850a0..7a5187010 100644
--- a/source/opt/fold_spec_constant_op_and_composite_pass.cpp
+++ b/source/opt/fold_spec_constant_op_and_composite_pass.cpp
@@ -28,6 +28,7 @@ namespace opt {
Pass::Status FoldSpecConstantOpAndCompositePass::Process() {
bool modified = false;
+ analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
// Traverse through all the constant defining instructions. For Normal
// Constants whose values are determined and do not depend on OpUndef
// instructions, records their values in two internal maps: id_to_const_val_
@@ -62,8 +63,8 @@ Pass::Status FoldSpecConstantOpAndCompositePass::Process() {
// used in OpSpecConstant{Composite|Op} instructions.
// TODO(qining): If the constant or its type has decoration, we may need
// to skip it.
- if (context()->get_constant_mgr()->GetType(inst) &&
- !context()->get_constant_mgr()->GetType(inst)->decoration_empty())
+ if (const_mgr->GetType(inst) &&
+ !const_mgr->GetType(inst)->decoration_empty())
continue;
switch (SpvOp opcode = inst->opcode()) {
// Records the values of Normal Constants.
@@ -80,15 +81,14 @@ Pass::Status FoldSpecConstantOpAndCompositePass::Process() {
// Constant will be turned in to a Normal Constant. In that case, a
// Constant instance should also be created successfully and recorded
// in the id_to_const_val_ and const_val_to_id_ mapps.
- if (auto const_value =
- context()->get_constant_mgr()->GetConstantFromInst(inst)) {
+ if (auto const_value = const_mgr->GetConstantFromInst(inst)) {
// Need to replace the OpSpecConstantComposite instruction with a
// corresponding OpConstantComposite instruction.
if (opcode == SpvOp::SpvOpSpecConstantComposite) {
inst->SetOpcode(SpvOp::SpvOpConstantComposite);
modified = true;
}
- context()->get_constant_mgr()->MapConstantToInst(const_value, inst);
+ const_mgr->MapConstantToInst(const_value, inst);
}
break;
}
@@ -146,6 +146,7 @@ bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp(
Instruction* FoldSpecConstantOpAndCompositePass::FoldWithInstructionFolder(
Module::inst_iterator* inst_iter_ptr) {
+ analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
// If one of operands to the instruction is not a
// constant, then we cannot fold this spec constant.
for (uint32_t i = 1; i < (*inst_iter_ptr)->NumInOperands(); i++) {
@@ -155,7 +156,7 @@ Instruction* FoldSpecConstantOpAndCompositePass::FoldWithInstructionFolder(
continue;
}
uint32_t id = operand.words[0];
- if (context()->get_constant_mgr()->FindDeclaredConstant(id) == nullptr) {
+ if (const_mgr->FindDeclaredConstant(id) == nullptr) {
return nullptr;
}
}
@@ -202,6 +203,7 @@ Instruction* FoldSpecConstantOpAndCompositePass::FoldWithInstructionFolder(
new_const_inst->InsertAfter(insert_pos);
get_def_use_mgr()->AnalyzeInstDefUse(new_const_inst);
}
+ const_mgr->MapInst(new_const_inst);
return new_const_inst;
}
@@ -285,8 +287,8 @@ utils::SmallVector<uint32_t, 2> EncodeIntegerAsWords(const analysis::Type& type,
Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
Module::inst_iterator* pos) {
const Instruction* inst = &**pos;
- const analysis::Type* result_type =
- context()->get_constant_mgr()->GetType(inst);
+ analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
+ const analysis::Type* result_type = const_mgr->GetType(inst);
SpvOp spec_opcode = static_cast<SpvOp>(inst->GetSingleWordInOperand(0));
// Check and collect operands.
std::vector<const analysis::Constant*> operands;
@@ -311,10 +313,9 @@ Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
// Scalar operation
const uint32_t result_val =
context()->get_instruction_folder().FoldScalars(spec_opcode, operands);
- auto result_const = context()->get_constant_mgr()->GetConstant(
+ auto result_const = const_mgr->GetConstant(
result_type, EncodeIntegerAsWords(*result_type, result_val));
- return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
- result_const, pos);
+ return const_mgr->BuildInstructionAndAddToModule(result_const, pos);
} else if (result_type->AsVector()) {
// Vector operation
const analysis::Type* element_type =
@@ -325,11 +326,10 @@ Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
operands);
std::vector<const analysis::Constant*> result_vector_components;
for (const uint32_t r : result_vec) {
- if (auto rc = context()->get_constant_mgr()->GetConstant(
+ if (auto rc = const_mgr->GetConstant(
element_type, EncodeIntegerAsWords(*element_type, r))) {
result_vector_components.push_back(rc);
- if (!context()->get_constant_mgr()->BuildInstructionAndAddToModule(
- rc, pos)) {
+ if (!const_mgr->BuildInstructionAndAddToModule(rc, pos)) {
assert(false &&
"Failed to build and insert constant declaring instruction "
"for the given vector component constant");
@@ -340,10 +340,8 @@ Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
}
auto new_vec_const = MakeUnique<analysis::VectorConstant>(
result_type->AsVector(), result_vector_components);
- auto reg_vec_const = context()->get_constant_mgr()->RegisterConstant(
- std::move(new_vec_const));
- return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
- reg_vec_const, pos);
+ auto reg_vec_const = const_mgr->RegisterConstant(std::move(new_vec_const));
+ return const_mgr->BuildInstructionAndAddToModule(reg_vec_const, pos);
} else {
// Cannot process invalid component wise operation. The result of component
// wise operation must be of integer or bool scalar or vector of
diff --git a/test/opt/fold_spec_const_op_composite_test.cpp b/test/opt/fold_spec_const_op_composite_test.cpp
index 7eddf7e99..c98a44c3d 100644
--- a/test/opt/fold_spec_const_op_composite_test.cpp
+++ b/test/opt/fold_spec_const_op_composite_test.cpp
@@ -105,6 +105,209 @@ TEST_F(FoldSpecConstantOpAndCompositePassBasicTest,
builder.GetCode(), builder.GetCode(), /* skip_nop = */ true);
}
+// Test where OpSpecConstantOp depends on another OpSpecConstantOp with
+// CompositeExtract
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, StackedCompositeExtract) {
+ AssemblyBuilder builder;
+ builder.AppendTypesConstantsGlobals({
+ // clang-format off
+ "%uint = OpTypeInt 32 0",
+ "%v3uint = OpTypeVector %uint 3",
+ "%uint_2 = OpConstant %uint 2",
+ "%uint_3 = OpConstant %uint 3",
+ // Folding target:
+ "%composite_0 = OpSpecConstantComposite %v3uint %uint_2 %uint_3 %uint_2",
+ "%op_0 = OpSpecConstantOp %uint CompositeExtract %composite_0 0",
+ "%op_1 = OpSpecConstantOp %uint CompositeExtract %composite_0 1",
+ "%op_2 = OpSpecConstantOp %uint IMul %op_0 %op_1",
+ "%composite_1 = OpSpecConstantComposite %v3uint %op_0 %op_1 %op_2",
+ "%op_3 = OpSpecConstantOp %uint CompositeExtract %composite_1 0",
+ "%op_4 = OpSpecConstantOp %uint IMul %op_2 %op_3",
+ // clang-format on
+ });
+
+ std::vector<const char*> expected = {
+ // clang-format off
+ "OpCapability Shader",
+ "OpCapability Float64",
+ "%1 = OpExtInstImport \"GLSL.std.450\"",
+ "OpMemoryModel Logical GLSL450",
+ "OpEntryPoint Vertex %main \"main\"",
+ "OpName %void \"void\"",
+ "OpName %main_func_type \"main_func_type\"",
+ "OpName %main \"main\"",
+ "OpName %main_func_entry_block \"main_func_entry_block\"",
+ "OpName %uint \"uint\"",
+ "OpName %v3uint \"v3uint\"",
+ "OpName %uint_2 \"uint_2\"",
+ "OpName %uint_3 \"uint_3\"",
+ "OpName %composite_0 \"composite_0\"",
+ "OpName %op_0 \"op_0\"",
+ "OpName %op_1 \"op_1\"",
+ "OpName %op_2 \"op_2\"",
+ "OpName %composite_1 \"composite_1\"",
+ "OpName %op_3 \"op_3\"",
+ "OpName %op_4 \"op_4\"",
+ "%void = OpTypeVoid",
+"%main_func_type = OpTypeFunction %void",
+ "%uint = OpTypeInt 32 0",
+ "%v3uint = OpTypeVector %uint 3",
+ "%uint_2 = OpConstant %uint 2",
+ "%uint_3 = OpConstant %uint 3",
+"%composite_0 = OpConstantComposite %v3uint %uint_2 %uint_3 %uint_2",
+ "%op_0 = OpConstant %uint 2",
+ "%op_1 = OpConstant %uint 3",
+ "%op_2 = OpConstant %uint 6",
+"%composite_1 = OpConstantComposite %v3uint %op_0 %op_1 %op_2",
+"%op_3 = OpConstant %uint 2",
+ "%op_4 = OpConstant %uint 12",
+ "%main = OpFunction %void None %main_func_type",
+"%main_func_entry_block = OpLabel",
+ "OpReturn",
+ "OpFunctionEnd",
+ // clang-format on
+ };
+ SinglePassRunAndCheck<FoldSpecConstantOpAndCompositePass>(
+ builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true);
+}
+
+// Test where OpSpecConstantOp depends on another OpSpecConstantOp with
+// VectorShuffle
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, StackedVectorShuffle) {
+ AssemblyBuilder builder;
+ builder.AppendTypesConstantsGlobals({
+ // clang-format off
+ "%uint = OpTypeInt 32 0",
+ "%v3uint = OpTypeVector %uint 3",
+ "%uint_1 = OpConstant %uint 1",
+ "%uint_2 = OpConstant %uint 2",
+ "%uint_3 = OpConstant %uint 3",
+ "%uint_4 = OpConstant %uint 4",
+ "%uint_5 = OpConstant %uint 5",
+ "%uint_6 = OpConstant %uint 6",
+ // Folding target:
+ "%composite_0 = OpSpecConstantComposite %v3uint %uint_1 %uint_2 %uint_3",
+ "%composite_1 = OpSpecConstantComposite %v3uint %uint_4 %uint_5 %uint_6",
+ "%vecshuffle = OpSpecConstantOp %v3uint VectorShuffle %composite_0 %composite_1 0 5 3",
+ "%op = OpSpecConstantOp %uint CompositeExtract %vecshuffle 1",
+ // clang-format on
+ });
+
+ std::vector<const char*> expected = {
+ // clang-format off
+ "OpCapability Shader",
+ "OpCapability Float64",
+ "%1 = OpExtInstImport \"GLSL.std.450\"",
+ "OpMemoryModel Logical GLSL450",
+ "OpEntryPoint Vertex %main \"main\"",
+ "OpName %void \"void\"",
+ "OpName %main_func_type \"main_func_type\"",
+ "OpName %main \"main\"",
+ "OpName %main_func_entry_block \"main_func_entry_block\"",
+ "OpName %uint \"uint\"",
+ "OpName %v3uint \"v3uint\"",
+ "OpName %uint_1 \"uint_1\"",
+ "OpName %uint_2 \"uint_2\"",
+ "OpName %uint_3 \"uint_3\"",
+ "OpName %uint_4 \"uint_4\"",
+ "OpName %uint_5 \"uint_5\"",
+ "OpName %uint_6 \"uint_6\"",
+ "OpName %composite_0 \"composite_0\"",
+ "OpName %composite_1 \"composite_1\"",
+ "OpName %vecshuffle \"vecshuffle\"",
+ "OpName %op \"op\"",
+ "%void = OpTypeVoid",
+"%main_func_type = OpTypeFunction %void",
+ "%uint = OpTypeInt 32 0",
+ "%v3uint = OpTypeVector %uint 3",
+ "%uint_1 = OpConstant %uint 1",
+ "%uint_2 = OpConstant %uint 2",
+ "%uint_3 = OpConstant %uint 3",
+ "%uint_4 = OpConstant %uint 4",
+ "%uint_5 = OpConstant %uint 5",
+ "%uint_6 = OpConstant %uint 6",
+"%composite_0 = OpConstantComposite %v3uint %uint_1 %uint_2 %uint_3",
+"%composite_1 = OpConstantComposite %v3uint %uint_4 %uint_5 %uint_6",
+"%vecshuffle = OpConstantComposite %v3uint %uint_1 %uint_6 %uint_4",
+ "%op = OpConstant %uint 6",
+ "%main = OpFunction %void None %main_func_type",
+"%main_func_entry_block = OpLabel",
+ "OpReturn",
+ "OpFunctionEnd",
+ // clang-format on
+ };
+ SinglePassRunAndCheck<FoldSpecConstantOpAndCompositePass>(
+ builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true);
+}
+
+// Test CompositeExtract with matrix
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeExtractMaxtrix) {
+ AssemblyBuilder builder;
+ builder.AppendTypesConstantsGlobals({
+ // clang-format off
+ "%uint = OpTypeInt 32 0",
+ "%v3uint = OpTypeVector %uint 3",
+ "%mat3x3 = OpTypeMatrix %v3uint 3",
+ "%uint_1 = OpConstant %uint 1",
+ "%uint_2 = OpConstant %uint 2",
+ "%uint_3 = OpConstant %uint 3",
+ // Folding target:
+ "%a = OpSpecConstantComposite %v3uint %uint_1 %uint_1 %uint_1",
+ "%b = OpSpecConstantComposite %v3uint %uint_1 %uint_1 %uint_3",
+ "%c = OpSpecConstantComposite %v3uint %uint_1 %uint_2 %uint_1",
+ "%op = OpSpecConstantComposite %mat3x3 %a %b %c",
+ "%x = OpSpecConstantOp %uint CompositeExtract %op 2 1",
+ "%y = OpSpecConstantOp %uint CompositeExtract %op 1 2",
+ // clang-format on
+ });
+
+ std::vector<const char*> expected = {
+ // clang-format off
+ "OpCapability Shader",
+ "OpCapability Float64",
+ "%1 = OpExtInstImport \"GLSL.std.450\"",
+ "OpMemoryModel Logical GLSL450",
+ "OpEntryPoint Vertex %main \"main\"",
+ "OpName %void \"void\"",
+ "OpName %main_func_type \"main_func_type\"",
+ "OpName %main \"main\"",
+ "OpName %main_func_entry_block \"main_func_entry_block\"",
+ "OpName %uint \"uint\"",
+ "OpName %v3uint \"v3uint\"",
+ "OpName %mat3x3 \"mat3x3\"",
+ "OpName %uint_1 \"uint_1\"",
+ "OpName %uint_2 \"uint_2\"",
+ "OpName %uint_3 \"uint_3\"",
+ "OpName %a \"a\"",
+ "OpName %b \"b\"",
+ "OpName %c \"c\"",
+ "OpName %op \"op\"",
+ "OpName %x \"x\"",
+ "OpName %y \"y\"",
+ "%void = OpTypeVoid",
+"%main_func_type = OpTypeFunction %void",
+ "%uint = OpTypeInt 32 0",
+ "%v3uint = OpTypeVector %uint 3",
+ "%mat3x3 = OpTypeMatrix %v3uint 3",
+ "%uint_1 = OpConstant %uint 1",
+ "%uint_2 = OpConstant %uint 2",
+ "%uint_3 = OpConstant %uint 3",
+ "%a = OpConstantComposite %v3uint %uint_1 %uint_1 %uint_1",
+ "%b = OpConstantComposite %v3uint %uint_1 %uint_1 %uint_3",
+ "%c = OpConstantComposite %v3uint %uint_1 %uint_2 %uint_1",
+ "%op = OpConstantComposite %mat3x3 %a %b %c",
+ "%x = OpConstant %uint 2",
+ "%y = OpConstant %uint 3",
+ "%main = OpFunction %void None %main_func_type",
+"%main_func_entry_block = OpLabel",
+ "OpReturn",
+ "OpFunctionEnd",
+ // clang-format on
+ };
+ SinglePassRunAndCheck<FoldSpecConstantOpAndCompositePass>(
+ builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true);
+}
+
// All types and some common constants that are potentially required in
// FoldSpecConstantOpAndCompositeTest.
std::vector<std::string> CommonTypesAndConstants() {
@@ -199,7 +402,7 @@ std::string StripOpNameInstructions(const std::string& str) {
struct FoldSpecConstantOpAndCompositePassTestCase {
// Original constants with unfolded spec constants.
std::vector<std::string> original;
- // Expected cosntants after folding.
+ // Expected constant after folding.
std::vector<std::string> expected;
};