Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/KhronosGroup/SPIRV-Cross.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHans-Kristian Arntzen <post@arntzen-software.no>2022-02-28 13:58:33 +0300
committerHans-Kristian Arntzen <post@arntzen-software.no>2022-02-28 13:58:33 +0300
commit5555f2784b0afadbf6e9ab8d97dc89515ef57f44 (patch)
tree704094cc5bcc7cb8fd2a880621630d4e2de67c0b
parentc08ee860c8ad4f020b19f9a372013e99cd4a00d9 (diff)
MSL: Refactor and fix use of quadgroup vs simdgroup.
-rw-r--r--reference/shaders-msl-no-opt/comp/subgroups.nocompat.invalid.vk.msl23.ios.simd.comp2
-rw-r--r--spirv_msl.cpp56
-rw-r--r--spirv_msl.hpp7
3 files changed, 35 insertions, 30 deletions
diff --git a/reference/shaders-msl-no-opt/comp/subgroups.nocompat.invalid.vk.msl23.ios.simd.comp b/reference/shaders-msl-no-opt/comp/subgroups.nocompat.invalid.vk.msl23.ios.simd.comp
index 1614bf3e..71916ebb 100644
--- a/reference/shaders-msl-no-opt/comp/subgroups.nocompat.invalid.vk.msl23.ios.simd.comp
+++ b/reference/shaders-msl-no-opt/comp/subgroups.nocompat.invalid.vk.msl23.ios.simd.comp
@@ -221,7 +221,7 @@ struct SSBO
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
-kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[quadgroups_per_threadgroup]], uint gl_SubgroupID [[quadgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_quadgroup]])
+kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[simdgroups_per_threadgroup]], uint gl_SubgroupID [[simdgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_simdgroup]])
{
uint4 gl_SubgroupEqMask = uint4(1 << gl_SubgroupInvocationID, uint3(0));
uint4 gl_SubgroupGeMask = uint4(insert_bits(0u, 0xFFFFFFFF, gl_SubgroupInvocationID, gl_SubgroupSize - gl_SubgroupInvocationID), uint3(0));
diff --git a/spirv_msl.cpp b/spirv_msl.cpp
index e46f68e2..d125aeaf 100644
--- a/spirv_msl.cpp
+++ b/spirv_msl.cpp
@@ -5370,7 +5370,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>");
statement("inline T spvSubgroupBroadcast(T value, ushort lane)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return quad_broadcast(value, lane);");
else
statement("return simd_broadcast(value, lane);");
@@ -5379,7 +5379,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>");
statement("inline bool spvSubgroupBroadcast(bool value, ushort lane)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return !!quad_broadcast((ushort)value, lane);");
else
statement("return !!simd_broadcast((ushort)value, lane);");
@@ -5388,7 +5388,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupBroadcast(vec<bool, N> value, ushort lane)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
else
statement("return (vec<bool, N>)simd_broadcast((vec<ushort, N>)value, lane);");
@@ -5400,7 +5400,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>");
statement("inline T spvSubgroupBroadcastFirst(T value)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return quad_broadcast_first(value);");
else
statement("return simd_broadcast_first(value);");
@@ -5409,7 +5409,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>");
statement("inline bool spvSubgroupBroadcastFirst(bool value)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return !!quad_broadcast_first((ushort)value);");
else
statement("return !!simd_broadcast_first((ushort)value);");
@@ -5418,7 +5418,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupBroadcastFirst(vec<bool, N> value)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value);");
else
statement("return (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value);");
@@ -5429,7 +5429,7 @@ void CompilerMSL::emit_custom_functions()
case SPVFuncImplSubgroupBallot:
statement("inline uint4 spvSubgroupBallot(bool value)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
{
statement("return uint4((quad_vote::vote_t)quad_ballot(value), 0, 0, 0);");
}
@@ -5557,7 +5557,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>");
statement("inline bool spvSubgroupAllEqual(T value)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return quad_all(all(value == quad_broadcast_first(value)));");
else
statement("return simd_all(all(value == simd_broadcast_first(value)));");
@@ -5566,7 +5566,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>");
statement("inline bool spvSubgroupAllEqual(bool value)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return quad_all(value) || !quad_any(value);");
else
statement("return simd_all(value) || !simd_any(value);");
@@ -5575,7 +5575,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>");
statement("inline bool spvSubgroupAllEqual(vec<bool, N> value)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return quad_all(all(value == (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value)));");
else
statement("return simd_all(all(value == (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value)));");
@@ -5587,7 +5587,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>");
statement("inline T spvSubgroupShuffle(T value, ushort lane)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return quad_shuffle(value, lane);");
else
statement("return simd_shuffle(value, lane);");
@@ -5596,7 +5596,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>");
statement("inline bool spvSubgroupShuffle(bool value, ushort lane)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return !!quad_shuffle((ushort)value, lane);");
else
statement("return !!simd_shuffle((ushort)value, lane);");
@@ -5605,7 +5605,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_shuffle((vec<ushort, N>)value, lane);");
else
statement("return (vec<bool, N>)simd_shuffle((vec<ushort, N>)value, lane);");
@@ -5617,7 +5617,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>");
statement("inline T spvSubgroupShuffleXor(T value, ushort mask)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return quad_shuffle_xor(value, mask);");
else
statement("return simd_shuffle_xor(value, mask);");
@@ -5626,7 +5626,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>");
statement("inline bool spvSubgroupShuffleXor(bool value, ushort mask)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return !!quad_shuffle_xor((ushort)value, mask);");
else
statement("return !!simd_shuffle_xor((ushort)value, mask);");
@@ -5635,7 +5635,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupShuffleXor(vec<bool, N> value, ushort mask)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, mask);");
else
statement("return (vec<bool, N>)simd_shuffle_xor((vec<ushort, N>)value, mask);");
@@ -5647,7 +5647,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>");
statement("inline T spvSubgroupShuffleUp(T value, ushort delta)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return quad_shuffle_up(value, delta);");
else
statement("return simd_shuffle_up(value, delta);");
@@ -5656,7 +5656,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>");
statement("inline bool spvSubgroupShuffleUp(bool value, ushort delta)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return !!quad_shuffle_up((ushort)value, delta);");
else
statement("return !!simd_shuffle_up((ushort)value, delta);");
@@ -5665,7 +5665,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupShuffleUp(vec<bool, N> value, ushort delta)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_shuffle_up((vec<ushort, N>)value, delta);");
else
statement("return (vec<bool, N>)simd_shuffle_up((vec<ushort, N>)value, delta);");
@@ -5677,7 +5677,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<typename T>");
statement("inline T spvSubgroupShuffleDown(T value, ushort delta)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return quad_shuffle_down(value, delta);");
else
statement("return simd_shuffle_down(value, delta);");
@@ -5686,7 +5686,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<>");
statement("inline bool spvSubgroupShuffleDown(bool value, ushort delta)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return !!quad_shuffle_down((ushort)value, delta);");
else
statement("return !!simd_shuffle_down((ushort)value, delta);");
@@ -5695,7 +5695,7 @@ void CompilerMSL::emit_custom_functions()
statement("template<uint N>");
statement("inline vec<bool, N> spvSubgroupShuffleDown(vec<bool, N> value, ushort delta)");
begin_scope();
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
statement("return (vec<bool, N>)quad_shuffle_down((vec<ushort, N>)value, delta);");
else
statement("return (vec<bool, N>)simd_shuffle_down((vec<ushort, N>)value, delta);");
@@ -13972,7 +13972,7 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i)
switch (op)
{
case OpGroupNonUniformElect:
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
emit_op(result_type, id, "quad_is_first()", false);
else
emit_op(result_type, id, "simd_is_first()", false);
@@ -14045,14 +14045,14 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i)
break;
case OpGroupNonUniformAll:
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
emit_unary_func_op(result_type, id, ops[3], "quad_all");
else
emit_unary_func_op(result_type, id, ops[3], "simd_all");
break;
case OpGroupNonUniformAny:
- if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
+ if (msl_options.use_quadgroup_operation())
emit_unary_func_op(result_type, id, ops[3], "quad_any");
else
emit_unary_func_op(result_type, id, ops[3], "simd_any");
@@ -14550,7 +14550,7 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
SPIRV_CROSS_THROW("NumSubgroups is handled specially with emulation.");
if (!msl_options.supports_msl_version(2))
SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
- return msl_options.is_ios() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";
+ return msl_options.use_quadgroup_operation() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";
case BuiltInSubgroupId:
if (msl_options.emulate_subgroups)
@@ -14558,7 +14558,7 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
SPIRV_CROSS_THROW("SubgroupId is handled specially with emulation.");
if (!msl_options.supports_msl_version(2))
SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
- return msl_options.is_ios() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";
+ return msl_options.use_quadgroup_operation() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";
case BuiltInSubgroupLocalInvocationId:
if (msl_options.emulate_subgroups)
@@ -14577,7 +14577,7 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin)
// We are generating a Metal kernel function.
if (!msl_options.supports_msl_version(2))
SPIRV_CROSS_THROW("Subgroup builtins in kernel functions require Metal 2.0.");
- return msl_options.is_ios() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
+ return msl_options.use_quadgroup_operation() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
}
else
SPIRV_CROSS_THROW("Subgroup builtins are not available in this type of function.");
diff --git a/spirv_msl.hpp b/spirv_msl.hpp
index 6591e47c..641f49e9 100644
--- a/spirv_msl.hpp
+++ b/spirv_msl.hpp
@@ -393,7 +393,7 @@ public:
// and will be addressed using the current ViewIndex.
bool arrayed_subpass_input = false;
- // Whether to use SIMD-group or quadgroup functions to implement group nnon-uniform
+ // Whether to use SIMD-group or quadgroup functions to implement group non-uniform
// operations. Some GPUs on iOS do not support the SIMD-group functions, only the
// quadgroup functions.
bool ios_use_simdgroup_functions = false;
@@ -445,6 +445,11 @@ public:
return platform == macOS;
}
+ bool use_quadgroup_operation() const
+ {
+ return is_ios() && !ios_use_simdgroup_functions;
+ }
+
void set_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0)
{
msl_version = make_msl_version(major, minor, patch);