diff options
author | Pedro J. Estébanez <pedrojrulez@gmail.com> | 2022-07-20 23:00:35 +0300 |
---|---|---|
committer | Pedro J. Estébanez <pedrojrulez@gmail.com> | 2022-07-22 14:39:37 +0300 |
commit | 1fe470b199909a7c16b8e095c4bcdf7fbf8d1592 (patch) | |
tree | b69f79052386738484fec7aba4c1cfcfb2f98af0 | |
parent | d8d051381f65b9606fb8016c79b7c3bab872eec3 (diff) |
HLSL: Implement GroupOperation(Inclusive/Exclusive)Scan.
-rw-r--r-- | reference/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp | 2 | ||||
-rw-r--r-- | shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp | 4 | ||||
-rw-r--r-- | spirv_hlsl.cpp | 28 |
3 files changed, 29 insertions, 5 deletions
```diff
diff --git a/reference/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp b/reference/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp
index b2f5a5a1..4c11a4b1 100644
--- a/reference/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp
+++ b/reference/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp
@@ -21,6 +21,8 @@ void comp_main()
     float3 first = WaveReadLaneFirst(20.0f.xxx);
     uint4 ballot_value = WaveActiveBallot(true);
     uint bit_count = countbits(ballot_value.x) + countbits(ballot_value.y) + countbits(ballot_value.z) + countbits(ballot_value.w);
+    uint inclusive_bit_count = countbits(ballot_value.x & gl_SubgroupLeMask.x) + countbits(ballot_value.y & gl_SubgroupLeMask.y) + countbits(ballot_value.z & gl_SubgroupLeMask.z) + countbits(ballot_value.w & gl_SubgroupLeMask.w);
+    uint exclusive_bit_count = countbits(ballot_value.x & gl_SubgroupLtMask.x) + countbits(ballot_value.y & gl_SubgroupLtMask.y) + countbits(ballot_value.z & gl_SubgroupLtMask.z) + countbits(ballot_value.w & gl_SubgroupLtMask.w);
     uint shuffled = WaveReadLaneAt(10u, 8u);
     uint shuffled_xor = WaveReadLaneAt(30u, WaveGetLaneIndex() ^ 8u);
     uint shuffled_up = WaveReadLaneAt(20u, WaveGetLaneIndex() - 4u);
diff --git a/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp b/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp
index 0f29d445..bbda0efd 100644
--- a/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp
+++ b/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp
@@ -40,8 +40,8 @@ void main()
 	//bool inverse_ballot_value = subgroupInverseBallot(ballot_value);
 	//bool bit_extracted = subgroupBallotBitExtract(uvec4(10u), 8u);
 	uint bit_count = subgroupBallotBitCount(ballot_value);
-	//uint inclusive_bit_count = subgroupBallotInclusiveBitCount(ballot_value);
-	//uint exclusive_bit_count = subgroupBallotExclusiveBitCount(ballot_value);
+	uint inclusive_bit_count = subgroupBallotInclusiveBitCount(ballot_value);
+	uint exclusive_bit_count = subgroupBallotExclusiveBitCount(ballot_value);
 	//uint lsb = subgroupBallotFindLSB(ballot_value);
 	//uint msb = subgroupBallotFindMSB(ballot_value);
 
diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp
index c5a37d2c..d0aa385f 100644
--- a/spirv_hlsl.cpp
+++ b/spirv_hlsl.cpp
@@ -4728,9 +4728,9 @@ void CompilerHLSL::emit_subgroup_op(const Instruction &i)
 	case OpGroupNonUniformBallotBitCount:
 	{
 		auto operation = static_cast<GroupOperation>(ops[3]);
+		bool forward = should_forward(ops[4]);
 		if (operation == GroupOperationReduce)
 		{
-			bool forward = should_forward(ops[4]);
 			auto left = join("countbits(", to_enclosed_expression(ops[4]), ".x) + countbits(",
 			                 to_enclosed_expression(ops[4]), ".y)");
 			auto right = join("countbits(", to_enclosed_expression(ops[4]), ".z) + countbits(",
@@ -4739,9 +4739,31 @@ void CompilerHLSL::emit_subgroup_op(const Instruction &i)
 			inherit_expression_dependencies(id, ops[4]);
 		}
 		else if (operation == GroupOperationInclusiveScan)
-			SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Inclusive Scan in HLSL.");
+		{
+			auto left = join("countbits(", to_enclosed_expression(ops[4]), ".x & gl_SubgroupLeMask.x) + countbits(",
+			                 to_enclosed_expression(ops[4]), ".y & gl_SubgroupLeMask.y)");
+			auto right = join("countbits(", to_enclosed_expression(ops[4]), ".z & gl_SubgroupLeMask.z) + countbits(",
+			                  to_enclosed_expression(ops[4]), ".w & gl_SubgroupLeMask.w)");
+			emit_op(result_type, id, join(left, " + ", right), forward);
+			if (!active_input_builtins.get(BuiltInSubgroupLeMask))
+			{
+				active_input_builtins.set(BuiltInSubgroupLeMask);
+				force_recompile_guarantee_forward_progress();
+			}
+		}
 		else if (operation == GroupOperationExclusiveScan)
-			SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Exclusive Scan in HLSL.");
+		{
+			auto left = join("countbits(", to_enclosed_expression(ops[4]), ".x & gl_SubgroupLtMask.x) + countbits(",
+			                 to_enclosed_expression(ops[4]), ".y & gl_SubgroupLtMask.y)");
+			auto right = join("countbits(", to_enclosed_expression(ops[4]), ".z & gl_SubgroupLtMask.z) + countbits(",
+			                  to_enclosed_expression(ops[4]), ".w & gl_SubgroupLtMask.w)");
+			emit_op(result_type, id, join(left, " + ", right), forward);
+			if (!active_input_builtins.get(BuiltInSubgroupLtMask))
+			{
+				active_input_builtins.set(BuiltInSubgroupLtMask);
+				force_recompile_guarantee_forward_progress();
+			}
+		}
 		else
 			SPIRV_CROSS_THROW("Invalid BitCount operation.");
 		break;
```