diff options
author | Hans-Kristian Arntzen <post@arntzen-software.no> | 2022-02-28 13:58:33 +0300 |
---|---|---|
committer | Hans-Kristian Arntzen <post@arntzen-software.no> | 2022-02-28 13:58:33 +0300 |
commit | 5555f2784b0afadbf6e9ab8d97dc89515ef57f44 (patch) | |
tree | 704094cc5bcc7cb8fd2a880621630d4e2de67c0b | |
parent | c08ee860c8ad4f020b19f9a372013e99cd4a00d9 (diff) |
MSL: Refactor and fix use of quadgroup vs simdgroup.
-rw-r--r-- | reference/shaders-msl-no-opt/comp/subgroups.nocompat.invalid.vk.msl23.ios.simd.comp | 2 | ||||
-rw-r--r-- | spirv_msl.cpp | 56 | ||||
-rw-r--r-- | spirv_msl.hpp | 7 |
3 files changed, 35 insertions, 30 deletions
diff --git a/reference/shaders-msl-no-opt/comp/subgroups.nocompat.invalid.vk.msl23.ios.simd.comp b/reference/shaders-msl-no-opt/comp/subgroups.nocompat.invalid.vk.msl23.ios.simd.comp index 1614bf3e..71916ebb 100644 --- a/reference/shaders-msl-no-opt/comp/subgroups.nocompat.invalid.vk.msl23.ios.simd.comp +++ b/reference/shaders-msl-no-opt/comp/subgroups.nocompat.invalid.vk.msl23.ios.simd.comp @@ -221,7 +221,7 @@ struct SSBO constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u); -kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[quadgroups_per_threadgroup]], uint gl_SubgroupID [[quadgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_quadgroup]]) +kernel void main0(device SSBO& _9 [[buffer(0)]], uint gl_NumSubgroups [[simdgroups_per_threadgroup]], uint gl_SubgroupID [[simdgroup_index_in_threadgroup]], uint gl_SubgroupSize [[thread_execution_width]], uint gl_SubgroupInvocationID [[thread_index_in_simdgroup]]) { uint4 gl_SubgroupEqMask = uint4(1 << gl_SubgroupInvocationID, uint3(0)); uint4 gl_SubgroupGeMask = uint4(insert_bits(0u, 0xFFFFFFFF, gl_SubgroupInvocationID, gl_SubgroupSize - gl_SubgroupInvocationID), uint3(0)); diff --git a/spirv_msl.cpp b/spirv_msl.cpp index e46f68e2..d125aeaf 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -5370,7 +5370,7 @@ void CompilerMSL::emit_custom_functions() statement("template<typename T>"); statement("inline T spvSubgroupBroadcast(T value, ushort lane)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return quad_broadcast(value, lane);"); else statement("return simd_broadcast(value, lane);"); @@ -5379,7 +5379,7 @@ void CompilerMSL::emit_custom_functions() statement("template<>"); statement("inline bool spvSubgroupBroadcast(bool value, ushort lane)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return !!quad_broadcast((ushort)value, lane);"); else statement("return !!simd_broadcast((ushort)value, lane);"); @@ -5388,7 +5388,7 @@ void CompilerMSL::emit_custom_functions() statement("template<uint N>"); statement("inline vec<bool, N> spvSubgroupBroadcast(vec<bool, N> value, ushort lane)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);"); else statement("return (vec<bool, N>)simd_broadcast((vec<ushort, N>)value, lane);"); @@ -5400,7 +5400,7 @@ void CompilerMSL::emit_custom_functions() statement("template<typename T>"); statement("inline T spvSubgroupBroadcastFirst(T value)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return quad_broadcast_first(value);"); else statement("return simd_broadcast_first(value);"); @@ -5409,7 +5409,7 @@ void CompilerMSL::emit_custom_functions() statement("template<>"); statement("inline bool spvSubgroupBroadcastFirst(bool value)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return !!quad_broadcast_first((ushort)value);"); else statement("return !!simd_broadcast_first((ushort)value);"); @@ -5418,7 +5418,7 @@ void CompilerMSL::emit_custom_functions() statement("template<uint N>"); statement("inline vec<bool, N> spvSubgroupBroadcastFirst(vec<bool, N> value)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value);"); else statement("return (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value);"); @@ -5429,7 +5429,7 @@ void CompilerMSL::emit_custom_functions() case SPVFuncImplSubgroupBallot: statement("inline uint4 spvSubgroupBallot(bool value)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) { statement("return uint4((quad_vote::vote_t)quad_ballot(value), 0, 0, 0);"); } @@ -5557,7 +5557,7 @@ void CompilerMSL::emit_custom_functions() statement("template<typename T>"); statement("inline bool spvSubgroupAllEqual(T value)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return quad_all(all(value == quad_broadcast_first(value)));"); else statement("return simd_all(all(value == simd_broadcast_first(value)));"); @@ -5566,7 +5566,7 @@ void CompilerMSL::emit_custom_functions() statement("template<>"); statement("inline bool spvSubgroupAllEqual(bool value)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return quad_all(value) || !quad_any(value);"); else statement("return simd_all(value) || !simd_any(value);"); @@ -5575,7 +5575,7 @@ void CompilerMSL::emit_custom_functions() statement("template<uint N>"); statement("inline bool spvSubgroupAllEqual(vec<bool, N> value)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return quad_all(all(value == (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value)));"); else statement("return simd_all(all(value == (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value)));"); @@ -5587,7 +5587,7 @@ void CompilerMSL::emit_custom_functions() statement("template<typename T>"); statement("inline T spvSubgroupShuffle(T value, ushort lane)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return quad_shuffle(value, lane);"); else statement("return simd_shuffle(value, lane);"); @@ -5596,7 +5596,7 @@ void CompilerMSL::emit_custom_functions() statement("template<>"); statement("inline bool spvSubgroupShuffle(bool value, ushort lane)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return !!quad_shuffle((ushort)value, lane);"); else statement("return !!simd_shuffle((ushort)value, lane);"); @@ -5605,7 +5605,7 @@ void CompilerMSL::emit_custom_functions() statement("template<uint N>"); statement("inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return (vec<bool, N>)quad_shuffle((vec<ushort, N>)value, lane);"); else statement("return (vec<bool, N>)simd_shuffle((vec<ushort, N>)value, lane);"); @@ -5617,7 +5617,7 @@ void CompilerMSL::emit_custom_functions() statement("template<typename T>"); statement("inline T spvSubgroupShuffleXor(T value, ushort mask)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return quad_shuffle_xor(value, mask);"); else statement("return simd_shuffle_xor(value, mask);"); @@ -5626,7 +5626,7 @@ void CompilerMSL::emit_custom_functions() statement("template<>"); statement("inline bool spvSubgroupShuffleXor(bool value, ushort mask)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return !!quad_shuffle_xor((ushort)value, mask);"); else statement("return !!simd_shuffle_xor((ushort)value, mask);"); @@ -5635,7 +5635,7 @@ void CompilerMSL::emit_custom_functions() statement("template<uint N>"); statement("inline vec<bool, N> spvSubgroupShuffleXor(vec<bool, N> value, ushort mask)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, mask);"); else statement("return (vec<bool, N>)simd_shuffle_xor((vec<ushort, N>)value, mask);"); @@ -5647,7 +5647,7 @@ void CompilerMSL::emit_custom_functions() statement("template<typename T>"); statement("inline T spvSubgroupShuffleUp(T value, ushort delta)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return quad_shuffle_up(value, delta);"); else statement("return simd_shuffle_up(value, delta);"); @@ -5656,7 +5656,7 @@ void CompilerMSL::emit_custom_functions() statement("template<>"); statement("inline bool spvSubgroupShuffleUp(bool value, ushort delta)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return !!quad_shuffle_up((ushort)value, delta);"); else statement("return !!simd_shuffle_up((ushort)value, delta);"); @@ -5665,7 +5665,7 @@ void CompilerMSL::emit_custom_functions() statement("template<uint N>"); statement("inline vec<bool, N> spvSubgroupShuffleUp(vec<bool, N> value, ushort delta)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return (vec<bool, N>)quad_shuffle_up((vec<ushort, N>)value, delta);"); else statement("return (vec<bool, N>)simd_shuffle_up((vec<ushort, N>)value, delta);"); @@ -5677,7 +5677,7 @@ void CompilerMSL::emit_custom_functions() statement("template<typename T>"); statement("inline T spvSubgroupShuffleDown(T value, ushort delta)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return quad_shuffle_down(value, delta);"); else statement("return simd_shuffle_down(value, delta);"); @@ -5686,7 +5686,7 @@ void CompilerMSL::emit_custom_functions() statement("template<>"); statement("inline bool spvSubgroupShuffleDown(bool value, ushort delta)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return !!quad_shuffle_down((ushort)value, delta);"); else statement("return !!simd_shuffle_down((ushort)value, delta);"); @@ -5695,7 +5695,7 @@ void CompilerMSL::emit_custom_functions() statement("template<uint N>"); statement("inline vec<bool, N> spvSubgroupShuffleDown(vec<bool, N> value, ushort delta)"); begin_scope(); - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) statement("return (vec<bool, N>)quad_shuffle_down((vec<ushort, N>)value, delta);"); else statement("return (vec<bool, N>)simd_shuffle_down((vec<ushort, N>)value, delta);"); @@ -13972,7 +13972,7 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i) switch (op) { case OpGroupNonUniformElect: - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) emit_op(result_type, id, "quad_is_first()", false); else emit_op(result_type, id, "simd_is_first()", false); @@ -14045,14 +14045,14 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i) break; case OpGroupNonUniformAll: - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) emit_unary_func_op(result_type, id, ops[3], "quad_all"); else emit_unary_func_op(result_type, id, ops[3], "simd_all"); break; case OpGroupNonUniformAny: - if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions) + if (msl_options.use_quadgroup_operation()) emit_unary_func_op(result_type, id, ops[3], "quad_any"); else emit_unary_func_op(result_type, id, ops[3], "simd_any"); @@ -14550,7 +14550,7 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin) SPIRV_CROSS_THROW("NumSubgroups is handled specially with emulation."); if (!msl_options.supports_msl_version(2)) SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0."); - return msl_options.is_ios() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup"; + return msl_options.use_quadgroup_operation() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup"; case BuiltInSubgroupId: if (msl_options.emulate_subgroups) @@ -14558,7 +14558,7 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin) SPIRV_CROSS_THROW("SubgroupId is handled specially with emulation."); if (!msl_options.supports_msl_version(2)) SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0."); - return msl_options.is_ios() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup"; + return msl_options.use_quadgroup_operation() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup"; case BuiltInSubgroupLocalInvocationId: if (msl_options.emulate_subgroups) @@ -14577,7 +14577,7 @@ string CompilerMSL::builtin_qualifier(BuiltIn builtin) // We are generating a Metal kernel function. if (!msl_options.supports_msl_version(2)) SPIRV_CROSS_THROW("Subgroup builtins in kernel functions require Metal 2.0."); - return msl_options.is_ios() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup"; + return msl_options.use_quadgroup_operation() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup"; } else SPIRV_CROSS_THROW("Subgroup builtins are not available in this type of function."); diff --git a/spirv_msl.hpp b/spirv_msl.hpp index 6591e47c..641f49e9 100644 --- a/spirv_msl.hpp +++ b/spirv_msl.hpp @@ -393,7 +393,7 @@ public: // and will be addressed using the current ViewIndex. bool arrayed_subpass_input = false; - // Whether to use SIMD-group or quadgroup functions to implement group nnon-uniform + // Whether to use SIMD-group or quadgroup functions to implement group non-uniform // operations. Some GPUs on iOS do not support the SIMD-group functions, only the // quadgroup functions. bool ios_use_simdgroup_functions = false; @@ -445,6 +445,11 @@ public: return platform == macOS; } + bool use_quadgroup_operation() const + { + return is_ios() && !ios_use_simdgroup_functions; + } + void set_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0) { msl_version = make_msl_version(major, minor, patch); |