From 4de9d6c2b6a69cfc87a2989674b475268994119c Mon Sep 17 00:00:00 2001 From: Hans-Kristian Arntzen Date: Mon, 31 Oct 2022 13:05:56 +0100 Subject: MSL: Handle implicit integer promotion rules. MSL inherits the behavior of C where arithmetic on small types is implicitly converted to int. SPIR-V does not have this behavior, so make sure that arithmetic results are handled correctly. --- .../comp/implicit-integer-promotion.comp | 93 ++++++++++++++++++++++ .../shaders-msl-no-opt/comp/int16min-literal.comp | 2 +- .../comp/implicit-integer-promotion.comp | 85 ++++++++++++++++++++ spirv_common.hpp | 27 +++++++ spirv_glsl.cpp | 55 ++++++++++--- spirv_glsl.hpp | 4 +- spirv_hlsl.cpp | 2 +- spirv_msl.cpp | 4 +- 8 files changed, 259 insertions(+), 13 deletions(-) create mode 100644 reference/shaders-msl-no-opt/comp/implicit-integer-promotion.comp create mode 100644 shaders-msl-no-opt/comp/implicit-integer-promotion.comp diff --git a/reference/shaders-msl-no-opt/comp/implicit-integer-promotion.comp b/reference/shaders-msl-no-opt/comp/implicit-integer-promotion.comp new file mode 100644 index 00000000..5c3ce49e --- /dev/null +++ b/reference/shaders-msl-no-opt/comp/implicit-integer-promotion.comp @@ -0,0 +1,93 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" + +#include +#include + +using namespace metal; + +struct BUF0 +{ + half2 f16s; + ushort2 u16; + short2 i16; + ushort4 u16s; + short4 i16s; + half f16; +}; + +static inline __attribute__((always_inline)) +void test_u16(device BUF0& v_24) +{ + v_24.f16 += as_type(ushort(((device ushort*)&v_24.u16)[0u] + ((device ushort*)&v_24.u16)[1u])); + v_24.f16 += as_type(ushort(((device ushort*)&v_24.u16)[0u] - ((device ushort*)&v_24.u16)[1u])); + v_24.f16 += as_type(ushort(((device ushort*)&v_24.u16)[0u] * ((device ushort*)&v_24.u16)[1u])); + v_24.f16 += as_type(ushort(((device ushort*)&v_24.u16)[0u] / ((device ushort*)&v_24.u16)[1u])); + v_24.f16 += as_type(ushort(((device ushort*)&v_24.u16)[0u] % ((device 
ushort*)&v_24.u16)[1u])); + v_24.f16 += as_type(ushort(((device ushort*)&v_24.u16)[0u] << ((device ushort*)&v_24.u16)[1u])); + v_24.f16 += as_type(ushort(((device ushort*)&v_24.u16)[0u] >> ((device ushort*)&v_24.u16)[1u])); + v_24.f16 += as_type(ushort(~((device ushort*)&v_24.u16)[0u])); + v_24.f16 += as_type(ushort(-((device ushort*)&v_24.u16)[0u])); + v_24.f16 += as_type(ushort(((device ushort*)&v_24.u16)[0u] ^ ((device ushort*)&v_24.u16)[1u])); + v_24.f16 += as_type(ushort(((device ushort*)&v_24.u16)[0u] & ((device ushort*)&v_24.u16)[1u])); + v_24.f16 += as_type(ushort(((device ushort*)&v_24.u16)[0u] | ((device ushort*)&v_24.u16)[1u])); +} + +static inline __attribute__((always_inline)) +void test_i16(device BUF0& v_24) +{ + v_24.f16 += as_type(short(((device short*)&v_24.i16)[0u] + ((device short*)&v_24.i16)[1u])); + v_24.f16 += as_type(short(((device short*)&v_24.i16)[0u] - ((device short*)&v_24.i16)[1u])); + v_24.f16 += as_type(short(((device short*)&v_24.i16)[0u] * ((device short*)&v_24.i16)[1u])); + v_24.f16 += as_type(short(((device short*)&v_24.i16)[0u] / ((device short*)&v_24.i16)[1u])); + v_24.f16 += as_type(short(((device short*)&v_24.i16)[0u] % ((device short*)&v_24.i16)[1u])); + v_24.f16 += as_type(short(((device short*)&v_24.i16)[0u] << ((device short*)&v_24.i16)[1u])); + v_24.f16 += as_type(short(((device short*)&v_24.i16)[0u] >> ((device short*)&v_24.i16)[1u])); + v_24.f16 += as_type(short(~((device short*)&v_24.i16)[0u])); + v_24.f16 += as_type(short(-((device short*)&v_24.i16)[0u])); + v_24.f16 += as_type(short(((device short*)&v_24.i16)[0u] ^ ((device short*)&v_24.i16)[1u])); + v_24.f16 += as_type(short(((device short*)&v_24.i16)[0u] & ((device short*)&v_24.i16)[1u])); + v_24.f16 += as_type(short(((device short*)&v_24.i16)[0u] | ((device short*)&v_24.i16)[1u])); +} + +static inline __attribute__((always_inline)) +void test_u16s(device BUF0& v_24) +{ + v_24.f16s += as_type(v_24.u16s.xy + v_24.u16s.zw); + v_24.f16s += as_type(v_24.u16s.xy - 
v_24.u16s.zw); + v_24.f16s += as_type(v_24.u16s.xy * v_24.u16s.zw); + v_24.f16s += as_type(v_24.u16s.xy / v_24.u16s.zw); + v_24.f16s += as_type(v_24.u16s.xy % v_24.u16s.zw); + v_24.f16s += as_type(v_24.u16s.xy << v_24.u16s.zw); + v_24.f16s += as_type(v_24.u16s.xy >> v_24.u16s.zw); + v_24.f16s += as_type(~v_24.u16s.xy); + v_24.f16s += as_type(-v_24.u16s.xy); + v_24.f16s += as_type(v_24.u16s.xy ^ v_24.u16s.zw); + v_24.f16s += as_type(v_24.u16s.xy & v_24.u16s.zw); + v_24.f16s += as_type(v_24.u16s.xy | v_24.u16s.zw); +} + +static inline __attribute__((always_inline)) +void test_i16s(device BUF0& v_24) +{ + v_24.f16s += as_type(v_24.i16s.xy + v_24.i16s.zw); + v_24.f16s += as_type(v_24.i16s.xy - v_24.i16s.zw); + v_24.f16s += as_type(v_24.i16s.xy * v_24.i16s.zw); + v_24.f16s += as_type(v_24.i16s.xy / v_24.i16s.zw); + v_24.f16s += as_type(v_24.i16s.xy % v_24.i16s.zw); + v_24.f16s += as_type(v_24.i16s.xy << v_24.i16s.zw); + v_24.f16s += as_type(v_24.i16s.xy >> v_24.i16s.zw); + v_24.f16s += as_type(~v_24.i16s.xy); + v_24.f16s += as_type(-v_24.i16s.xy); + v_24.f16s += as_type(v_24.i16s.xy ^ v_24.i16s.zw); + v_24.f16s += as_type(v_24.i16s.xy & v_24.i16s.zw); + v_24.f16s += as_type(v_24.i16s.xy | v_24.i16s.zw); +} + +kernel void main0(device BUF0& v_24 [[buffer(0)]]) +{ + test_u16(v_24); + test_i16(v_24); + test_u16s(v_24); + test_i16s(v_24); +} + diff --git a/reference/shaders-msl-no-opt/comp/int16min-literal.comp b/reference/shaders-msl-no-opt/comp/int16min-literal.comp index a2b36ede..d73768c3 100644 --- a/reference/shaders-msl-no-opt/comp/int16min-literal.comp +++ b/reference/shaders-msl-no-opt/comp/int16min-literal.comp @@ -18,7 +18,7 @@ constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u); kernel void main0(constant UBO& _12 [[buffer(0)]], device SSBO& _24 [[buffer(1)]]) { short v = as_type(_12.b); - v ^= short(-32768); + v = short(v ^ short(-32768)); _24.a = as_type(v); } diff --git a/shaders-msl-no-opt/comp/implicit-integer-promotion.comp 
b/shaders-msl-no-opt/comp/implicit-integer-promotion.comp new file mode 100644 index 00000000..a0ee95b3 --- /dev/null +++ b/shaders-msl-no-opt/comp/implicit-integer-promotion.comp @@ -0,0 +1,85 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require + +layout(set = 0, binding = 0) buffer BUF0 +{ + f16vec2 f16s; + u16vec2 u16; + i16vec2 i16; + u16vec4 u16s; + i16vec4 i16s; + float16_t f16; +}; + +void test_i16() +{ + f16 += int16BitsToFloat16(i16.x + i16.y); + f16 += int16BitsToFloat16(i16.x - i16.y); + f16 += int16BitsToFloat16(i16.x * i16.y); + f16 += int16BitsToFloat16(i16.x / i16.y); + f16 += int16BitsToFloat16(i16.x % i16.y); + f16 += int16BitsToFloat16(i16.x << i16.y); + f16 += int16BitsToFloat16(i16.x >> i16.y); + f16 += int16BitsToFloat16(~i16.x); + f16 += int16BitsToFloat16(-i16.x); + f16 += int16BitsToFloat16(i16.x ^ i16.y); + f16 += int16BitsToFloat16(i16.x & i16.y); + f16 += int16BitsToFloat16(i16.x | i16.y); +} + +void test_u16() +{ + f16 += uint16BitsToFloat16(u16.x + u16.y); + f16 += uint16BitsToFloat16(u16.x - u16.y); + f16 += uint16BitsToFloat16(u16.x * u16.y); + f16 += uint16BitsToFloat16(u16.x / u16.y); + f16 += uint16BitsToFloat16(u16.x % u16.y); + f16 += uint16BitsToFloat16(u16.x << u16.y); + f16 += uint16BitsToFloat16(u16.x >> u16.y); + f16 += uint16BitsToFloat16(~u16.x); + f16 += uint16BitsToFloat16(-u16.x); + f16 += uint16BitsToFloat16(u16.x ^ u16.y); + f16 += uint16BitsToFloat16(u16.x & u16.y); + f16 += uint16BitsToFloat16(u16.x | u16.y); +} + +void test_u16s() +{ + f16s += uint16BitsToFloat16(u16s.xy + u16s.zw); + f16s += uint16BitsToFloat16(u16s.xy - u16s.zw); + f16s += uint16BitsToFloat16(u16s.xy * u16s.zw); + f16s += uint16BitsToFloat16(u16s.xy / u16s.zw); + f16s += uint16BitsToFloat16(u16s.xy % u16s.zw); + f16s += uint16BitsToFloat16(u16s.xy << u16s.zw); + f16s += uint16BitsToFloat16(u16s.xy >> u16s.zw); + f16s += 
uint16BitsToFloat16(~u16s.xy); + f16s += uint16BitsToFloat16(-u16s.xy); + f16s += uint16BitsToFloat16(u16s.xy ^ u16s.zw); + f16s += uint16BitsToFloat16(u16s.xy & u16s.zw); + f16s += uint16BitsToFloat16(u16s.xy | u16s.zw); +} + +void test_i16s() +{ + f16s += int16BitsToFloat16(i16s.xy + i16s.zw); + f16s += int16BitsToFloat16(i16s.xy - i16s.zw); + f16s += int16BitsToFloat16(i16s.xy * i16s.zw); + f16s += int16BitsToFloat16(i16s.xy / i16s.zw); + f16s += int16BitsToFloat16(i16s.xy % i16s.zw); + f16s += int16BitsToFloat16(i16s.xy << i16s.zw); + f16s += int16BitsToFloat16(i16s.xy >> i16s.zw); + f16s += int16BitsToFloat16(~i16s.xy); + f16s += int16BitsToFloat16(-i16s.xy); + f16s += int16BitsToFloat16(i16s.xy ^ i16s.zw); + f16s += int16BitsToFloat16(i16s.xy & i16s.zw); + f16s += int16BitsToFloat16(i16s.xy | i16s.zw); +} + +void main() +{ + test_u16(); + test_i16(); + test_u16s(); + test_i16s(); +} diff --git a/spirv_common.hpp b/spirv_common.hpp index 5c2ad747..32f91c72 100644 --- a/spirv_common.hpp +++ b/spirv_common.hpp @@ -1796,6 +1796,33 @@ static inline bool opcode_is_sign_invariant(spv::Op opcode) } } +static inline bool opcode_can_promote_integer_implicitly(spv::Op opcode) +{ + switch (opcode) + { + case spv::OpSNegate: + case spv::OpNot: + case spv::OpBitwiseAnd: + case spv::OpBitwiseOr: + case spv::OpBitwiseXor: + case spv::OpShiftLeftLogical: + case spv::OpShiftRightLogical: + case spv::OpShiftRightArithmetic: + case spv::OpIAdd: + case spv::OpISub: + case spv::OpIMul: + case spv::OpSDiv: + case spv::OpUDiv: + case spv::OpSRem: + case spv::OpUMod: + case spv::OpSMod: + return true; + + default: + return false; + } +} + struct SetBindingPair { uint32_t desc_set; diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp index a8129641..cf621904 100644 --- a/spirv_glsl.cpp +++ b/spirv_glsl.cpp @@ -5984,6 +5984,14 @@ void CompilerGLSL::emit_unary_op(uint32_t result_type, uint32_t result_id, uint3 inherit_expression_dependencies(result_id, op0); } +void 
CompilerGLSL::emit_unary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op) +{ + auto &type = get(result_type); + bool forward = should_forward(op0); + emit_op(result_type, result_id, join(type_to_glsl(type), "(", op, to_enclosed_unpacked_expression(op0), ")"), forward); + inherit_expression_dependencies(result_id, op0); +} + void CompilerGLSL::emit_binary_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op) { // Various FP arithmetic opcodes such as add, sub, mul will hit this. @@ -6127,7 +6135,9 @@ bool CompilerGLSL::emit_complex_bitcast(uint32_t result_type, uint32_t id, uint3 } void CompilerGLSL::emit_binary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, - const char *op, SPIRType::BaseType input_type, bool skip_cast_if_equal_type) + const char *op, SPIRType::BaseType input_type, + bool skip_cast_if_equal_type, + bool implicit_integer_promotion) { string cast_op0, cast_op1; auto expected_type = binary_op_bitcast_helper(cast_op0, cast_op1, input_type, op0, op1, skip_cast_if_equal_type); @@ -6136,17 +6146,23 @@ void CompilerGLSL::emit_binary_op_cast(uint32_t result_type, uint32_t result_id, // We might have casted away from the result type, so bitcast again. // For example, arithmetic right shift with uint inputs. // Special case boolean outputs since relational opcodes output booleans instead of int/uint. + auto bitop = join(cast_op0, " ", op, " ", cast_op1); string expr; - if (out_type.basetype != input_type && out_type.basetype != SPIRType::Boolean) + + if (implicit_integer_promotion) + { + // Simple value cast. 
+ expr = join(type_to_glsl(out_type), '(', bitop, ')'); + } + else if (out_type.basetype != input_type && out_type.basetype != SPIRType::Boolean) { expected_type.basetype = input_type; - expr = bitcast_glsl_op(out_type, expected_type); - expr += '('; - expr += join(cast_op0, " ", op, " ", cast_op1); - expr += ')'; + expr = join(bitcast_glsl_op(out_type, expected_type), '(', bitop, ')'); } else - expr += join(cast_op0, " ", op, " ", cast_op1); + { + expr = std::move(bitop); + } emit_op(result_type, result_id, expr, should_forward(op0) && should_forward(op1)); inherit_expression_dependencies(result_id, op0); @@ -10751,8 +10767,10 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) #define GLSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op) #define GLSL_BOP_CAST(op, type) \ - emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode)) + emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, \ + opcode_is_sign_invariant(opcode), implicit_integer_promotion) #define GLSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op) +#define GLSL_UOP_CAST(op) emit_unary_op_cast(ops[0], ops[1], ops[2], #op) #define GLSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op) #define GLSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op) #define GLSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op) @@ -10766,6 +10784,13 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) auto int_type = to_signed_basetype(integer_width); auto uint_type = to_unsigned_basetype(integer_width); + // Handle C implicit integer promotion rules. + // If we get implicit promotion to int, need to make sure we cast by value to intended return type, + // otherwise, future sign-dependent operations and bitcasts will break. 
+ bool implicit_integer_promotion = integer_width < 32 && backend.implicit_c_integer_promotion_rules && + opcode_can_promote_integer_implicitly(opcode) && + get(ops[0]).vecsize == 1; + opcode = get_remapped_spirv_op(opcode); switch (opcode) @@ -11600,6 +11625,12 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) break; case OpSNegate: + if (implicit_integer_promotion || expression_type_id(ops[2]) != ops[0]) + GLSL_UOP_CAST(-); + else + GLSL_UOP(-); + break; + case OpFNegate: GLSL_UOP(-); break; @@ -11744,6 +11775,9 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) auto expr = join(to_enclosed_expression(op0), " - ", to_enclosed_expression(op1), " * ", "(", to_enclosed_expression(op0), " / ", to_enclosed_expression(op1), ")"); + if (implicit_integer_promotion) + expr = join(type_to_glsl(get(result_type)), '(', expr, ')'); + emit_op(result_type, result_id, expr, forward); inherit_expression_dependencies(result_id, op0); inherit_expression_dependencies(result_id, op1); @@ -11841,7 +11875,10 @@ void CompilerGLSL::emit_instruction(const Instruction &instruction) } case OpNot: - GLSL_UOP(~); + if (implicit_integer_promotion || expression_type_id(ops[2]) != ops[0]) + GLSL_UOP_CAST(~); + else + GLSL_UOP(~); break; case OpUMod: diff --git a/spirv_glsl.hpp b/spirv_glsl.hpp index 2d1dad6c..15ffac29 100644 --- a/spirv_glsl.hpp +++ b/spirv_glsl.hpp @@ -619,6 +619,7 @@ protected: bool support_64bit_switch = false; bool workgroup_size_is_hidden = false; bool requires_relaxed_precision_analysis = false; + bool implicit_c_integer_promotion_rules = false; } backend; void emit_struct(SPIRType &type); @@ -691,7 +692,7 @@ protected: void emit_unrolled_binary_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op, bool negate, SPIRType::BaseType expected_type); void emit_binary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op, - SPIRType::BaseType input_type, bool 
skip_cast_if_equal_type); + SPIRType::BaseType input_type, bool skip_cast_if_equal_type, bool implicit_integer_promotion); SPIRType binary_op_bitcast_helper(std::string &cast_op0, std::string &cast_op1, SPIRType::BaseType &input_type, uint32_t op0, uint32_t op1, bool skip_cast_if_equal_type); @@ -702,6 +703,7 @@ protected: uint32_t false_value); void emit_unary_op(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op); + void emit_unary_op_cast(uint32_t result_type, uint32_t result_id, uint32_t op0, const char *op); bool expression_is_forwarded(uint32_t id) const; bool expression_suppresses_usage_tracking(uint32_t id) const; bool expression_read_implies_multiple_reads(uint32_t id) const; diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp index 1291f7ee..65c9882b 100644 --- a/spirv_hlsl.cpp +++ b/spirv_hlsl.cpp @@ -4965,7 +4965,7 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction) #define HLSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op) #define HLSL_BOP_CAST(op, type) \ - emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode)) + emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode), false) #define HLSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op) #define HLSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op) #define HLSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op) diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 11aefe64..58090ebb 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -1439,6 +1439,7 @@ string CompilerMSL::compile() // Arrays which are part of buffer objects are never considered to be value types (just plain C-style). 
backend.array_is_value_type_in_buffer_blocks = false; backend.support_pointer_to_pointer = true; + backend.implicit_c_integer_promotion_rules = true; capture_output_to_buffer = msl_options.capture_output_to_buffer; is_rasterization_disabled = msl_options.disable_rasterization || capture_output_to_buffer; @@ -8167,8 +8168,9 @@ void CompilerMSL::emit_instruction(const Instruction &instruction) { #define MSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op) #define MSL_PTR_BOP(op) emit_binary_ptr_op(ops[0], ops[1], ops[2], ops[3], #op) + // MSL does care about implicit integer promotion, but those cases are all handled in common code. #define MSL_BOP_CAST(op, type) \ - emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode)) + emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode), false) #define MSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op) #define MSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op) #define MSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op) -- cgit v1.2.3