Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/KhronosGroup/SPIRV-Cross.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPedro J. Estébanez <pedrojrulez@gmail.com>2022-07-20 23:00:35 +0300
committerPedro J. Estébanez <pedrojrulez@gmail.com>2022-07-22 14:39:37 +0300
commit1fe470b199909a7c16b8e095c4bcdf7fbf8d1592 (patch)
treeb69f79052386738484fec7aba4c1cfcfb2f98af0
parentd8d051381f65b9606fb8016c79b7c3bab872eec3 (diff)
HLSL: Implement GroupOperation(Inclusive/Exclusive)Scan.
-rw-r--r--reference/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp2
-rw-r--r--shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp4
-rw-r--r--spirv_hlsl.cpp28
3 files changed, 29 insertions, 5 deletions
diff --git a/reference/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp b/reference/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp
index b2f5a5a1..4c11a4b1 100644
--- a/reference/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp
+++ b/reference/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp
@@ -21,6 +21,8 @@ void comp_main()
float3 first = WaveReadLaneFirst(20.0f.xxx);
uint4 ballot_value = WaveActiveBallot(true);
uint bit_count = countbits(ballot_value.x) + countbits(ballot_value.y) + countbits(ballot_value.z) + countbits(ballot_value.w);
+ uint inclusive_bit_count = countbits(ballot_value.x & gl_SubgroupLeMask.x) + countbits(ballot_value.y & gl_SubgroupLeMask.y) + countbits(ballot_value.z & gl_SubgroupLeMask.z) + countbits(ballot_value.w & gl_SubgroupLeMask.w);
+ uint exclusive_bit_count = countbits(ballot_value.x & gl_SubgroupLtMask.x) + countbits(ballot_value.y & gl_SubgroupLtMask.y) + countbits(ballot_value.z & gl_SubgroupLtMask.z) + countbits(ballot_value.w & gl_SubgroupLtMask.w);
uint shuffled = WaveReadLaneAt(10u, 8u);
uint shuffled_xor = WaveReadLaneAt(30u, WaveGetLaneIndex() ^ 8u);
uint shuffled_up = WaveReadLaneAt(20u, WaveGetLaneIndex() - 4u);
diff --git a/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp b/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp
index 0f29d445..bbda0efd 100644
--- a/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp
+++ b/shaders-hlsl-no-opt/comp/subgroups.invalid.nofxc.sm60.comp
@@ -40,8 +40,8 @@ void main()
//bool inverse_ballot_value = subgroupInverseBallot(ballot_value);
//bool bit_extracted = subgroupBallotBitExtract(uvec4(10u), 8u);
uint bit_count = subgroupBallotBitCount(ballot_value);
- //uint inclusive_bit_count = subgroupBallotInclusiveBitCount(ballot_value);
- //uint exclusive_bit_count = subgroupBallotExclusiveBitCount(ballot_value);
+ uint inclusive_bit_count = subgroupBallotInclusiveBitCount(ballot_value);
+ uint exclusive_bit_count = subgroupBallotExclusiveBitCount(ballot_value);
//uint lsb = subgroupBallotFindLSB(ballot_value);
//uint msb = subgroupBallotFindMSB(ballot_value);
diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp
index c5a37d2c..d0aa385f 100644
--- a/spirv_hlsl.cpp
+++ b/spirv_hlsl.cpp
@@ -4728,9 +4728,9 @@ void CompilerHLSL::emit_subgroup_op(const Instruction &i)
case OpGroupNonUniformBallotBitCount:
{
auto operation = static_cast<GroupOperation>(ops[3]);
+ bool forward = should_forward(ops[4]);
if (operation == GroupOperationReduce)
{
- bool forward = should_forward(ops[4]);
auto left = join("countbits(", to_enclosed_expression(ops[4]), ".x) + countbits(",
to_enclosed_expression(ops[4]), ".y)");
auto right = join("countbits(", to_enclosed_expression(ops[4]), ".z) + countbits(",
@@ -4739,9 +4739,31 @@ void CompilerHLSL::emit_subgroup_op(const Instruction &i)
inherit_expression_dependencies(id, ops[4]);
}
else if (operation == GroupOperationInclusiveScan)
- SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Inclusive Scan in HLSL.");
+ {
+ auto left = join("countbits(", to_enclosed_expression(ops[4]), ".x & gl_SubgroupLeMask.x) + countbits(",
+ to_enclosed_expression(ops[4]), ".y & gl_SubgroupLeMask.y)");
+ auto right = join("countbits(", to_enclosed_expression(ops[4]), ".z & gl_SubgroupLeMask.z) + countbits(",
+ to_enclosed_expression(ops[4]), ".w & gl_SubgroupLeMask.w)");
+ emit_op(result_type, id, join(left, " + ", right), forward);
+ if (!active_input_builtins.get(BuiltInSubgroupLeMask))
+ {
+ active_input_builtins.set(BuiltInSubgroupLeMask);
+ force_recompile_guarantee_forward_progress();
+ }
+ }
else if (operation == GroupOperationExclusiveScan)
- SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Exclusive Scan in HLSL.");
+ {
+ auto left = join("countbits(", to_enclosed_expression(ops[4]), ".x & gl_SubgroupLtMask.x) + countbits(",
+ to_enclosed_expression(ops[4]), ".y & gl_SubgroupLtMask.y)");
+ auto right = join("countbits(", to_enclosed_expression(ops[4]), ".z & gl_SubgroupLtMask.z) + countbits(",
+ to_enclosed_expression(ops[4]), ".w & gl_SubgroupLtMask.w)");
+ emit_op(result_type, id, join(left, " + ", right), forward);
+ if (!active_input_builtins.get(BuiltInSubgroupLtMask))
+ {
+ active_input_builtins.set(BuiltInSubgroupLtMask);
+ force_recompile_guarantee_forward_progress();
+ }
+ }
else
SPIRV_CROSS_THROW("Invalid BitCount operation.");
break;