From 54d4e77fa5599b855f5c463646c0e8922d5e6064 Mon Sep 17 00:00:00 2001 From: Spencer Fricke Date: Wed, 9 Nov 2022 00:50:42 +0900 Subject: spirv-opt: Add const folding for CompositeInsert (#4943) * spirv-opt: Add const folding pass for CompositeInsert * spirv-opt: Fix anas stack-use-after-scope --- source/opt/const_folding_rules.cpp | 78 ++++++++++++++++++++++++++ test/opt/fold_spec_const_op_composite_test.cpp | 66 ++++++++++++++++++++++ 2 files changed, 144 insertions(+) diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp index 64475a6d5..6d80fbb52 100644 --- a/source/opt/const_folding_rules.cpp +++ b/source/opt/const_folding_rules.cpp @@ -120,6 +120,83 @@ ConstantFoldingRule FoldExtractWithConstants() { }; } +// Folds an OpcompositeInsert where input is a composite constant. +ConstantFoldingRule FoldInsertWithConstants() { + return [](IRContext* context, Instruction* inst, + const std::vector& constants) + -> const analysis::Constant* { + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + const analysis::Constant* object = constants[0]; + const analysis::Constant* composite = constants[1]; + if (object == nullptr || composite == nullptr) { + return nullptr; + } + + // If there is more than 1 index, then each additional constant used by the + // index will need to be recreated to use the inserted object. + std::vector chain; + std::vector components; + const analysis::Type* type = nullptr; + + // Work down hierarchy and add all the indexes, not including the final + // index. + for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { + if (i != inst->NumInOperands() - 1) { + chain.push_back(composite); + } + const uint32_t index = inst->GetSingleWordInOperand(i); + components = composite->AsCompositeConstant()->GetComponents(); + type = composite->AsCompositeConstant()->type(); + composite = components[index]; + } + + // Final index in hierarchy is inserted with new object. + const uint32_t final_index = + inst->GetSingleWordInOperand(inst->NumInOperands() - 1); + std::vector ids; + for (size_t i = 0; i < components.size(); i++) { + const analysis::Constant* constant = + (i == final_index) ? object : components[i]; + Instruction* member_inst = const_mgr->GetDefiningInstruction(constant); + ids.push_back(member_inst->result_id()); + } + const analysis::Constant* new_constant = const_mgr->GetConstant(type, ids); + + // Work backwards up the chain and replace each index with new constant. + for (size_t i = chain.size(); i > 0; i--) { + // Need to insert any previous instruction into the module first. + // Can't just insert in types_values_begin() because it will move above + // where the types are declared + for (Module::inst_iterator inst_iter = context->types_values_begin(); + inst_iter != context->types_values_end(); ++inst_iter) { + Instruction* x = &*inst_iter; + if (inst->result_id() == x->result_id()) { + const_mgr->BuildInstructionAndAddToModule(new_constant, &inst_iter); + break; + } + } + + composite = chain[i - 1]; + components = composite->AsCompositeConstant()->GetComponents(); + type = composite->AsCompositeConstant()->type(); + ids.clear(); + for (size_t k = 0; k < components.size(); k++) { + const uint32_t index = + inst->GetSingleWordInOperand(1 + static_cast(i)); + const analysis::Constant* constant = + (k == index) ? new_constant : components[k]; + const uint32_t constant_id = + const_mgr->FindDeclaredConstant(constant, 0); + ids.push_back(constant_id); + } + new_constant = const_mgr->GetConstant(type, ids); + } + + // If multiple constants were created, only need to return the top index. + return new_constant; + }; +} + ConstantFoldingRule FoldVectorShuffleWithConstants() { return [](IRContext* context, Instruction* inst, const std::vector& constants) @@ -1410,6 +1487,7 @@ void ConstantFoldingRules::AddFoldingRules() { rules_[spv::Op::OpCompositeConstruct].push_back(FoldCompositeWithConstants()); rules_[spv::Op::OpCompositeExtract].push_back(FoldExtractWithConstants()); + rules_[spv::Op::OpCompositeInsert].push_back(FoldInsertWithConstants()); rules_[spv::Op::OpConvertFToS].push_back(FoldFToI()); rules_[spv::Op::OpConvertFToU].push_back(FoldFToI()); diff --git a/test/opt/fold_spec_const_op_composite_test.cpp b/test/opt/fold_spec_const_op_composite_test.cpp index c98a44c3d..e2374c573 100644 --- a/test/opt/fold_spec_const_op_composite_test.cpp +++ b/test/opt/fold_spec_const_op_composite_test.cpp @@ -308,6 +308,72 @@ TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeExtractMaxtrix) { builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true); } +TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertVector) { + const std::string test = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %v3uint = OpTypeVector %uint 3 + %uint_2 = OpConstant %uint 2 + %uint_3 = OpConstant %uint 3 + %8 = OpConstantNull %uint + %9 = OpSpecConstantComposite %v3uint %uint_2 %uint_2 %uint_2 + ; CHECK: %15 = OpConstantComposite %v3uint %uint_3 %uint_2 %uint_2 + ; CHECK: %uint_3_0 = OpConstant %uint 3 + ; CHECK: %17 = OpConstantComposite %v3uint %8 %uint_2 %uint_2 + ; CHECK: %18 = OpConstantNull %uint + %10 = OpSpecConstantOp %v3uint CompositeInsert %uint_3 %9 0 + %11 = OpSpecConstantOp %uint CompositeExtract %10 0 + %12 = OpSpecConstantOp %v3uint CompositeInsert %8 %9 0 + %13 = OpSpecConstantOp %uint CompositeExtract %12 0 + %1 = OpFunction %void None %3 + %14 = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(test, false); +} + +TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertMatrix) { + const std::string test = + R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %1 "main" + OpExecutionMode %1 LocalSize 1 1 1 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v3float = OpTypeVector %float 3 +%mat3v3float = OpTypeMatrix %v3float 3 + %float_1 = OpConstant %float 1 + %float_2 = OpConstant %float 2 + %9 = OpSpecConstantComposite %v3float %float_1 %float_1 %float_1 + %10 = OpSpecConstantComposite %v3float %float_1 %float_1 %float_1 + %11 = OpSpecConstantComposite %v3float %float_1 %float_2 %float_1 + %12 = OpSpecConstantComposite %mat3v3float %9 %10 %11 + ; CHECK: %float_2_0 = OpConstant %float 2 + ; CHECK: %18 = OpConstantComposite %v3float %float_1 %float_1 %float_2 + ; CHECK: %19 = OpConstantComposite %mat3v3float %9 %18 %11 + ; CHECK: %float_2_1 = OpConstant %float 2 + %13 = OpSpecConstantOp %float CompositeExtract %12 2 1 + %14 = OpSpecConstantOp %mat3v3float CompositeInsert %13 %12 1 2 + %15 = OpSpecConstantOp %float CompositeExtract %14 1 2 + %1 = OpFunction %void None %3 + %16 = OpLabel + OpReturn + OpFunctionEnd +)"; + + SinglePassRunAndMatch(test, false); +} + // All types and some common constants that are potentially required in // FoldSpecConstantOpAndCompositeTest. std::vector CommonTypesAndConstants() { -- cgit v1.2.3