diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-07-22 19:04:22 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-09-30 16:05:00 +0300 |
commit | c6851d1a9c8cab163cd86e539bc1fa42a52bf823 (patch) | |
tree | 4e1050b5b91ec9457481c5c7b34a10c931035059 | |
parent | 30c7e3ab2d11723977ee2402c28f75b92280650b (diff) |
Add callbacks for SSRUInteger (marian-ssru)
-rw-r--r-- | callbacks/configs.h | 21 | ||||
-rw-r--r-- | callbacks/implementations.inl | 101 | ||||
-rw-r--r-- | intrinsics.h | 9 |
3 files changed, 131 insertions, 0 deletions
diff --git a/callbacks/configs.h b/callbacks/configs.h index 8e2eacc..4dfee71 100644 --- a/callbacks/configs.h +++ b/callbacks/configs.h @@ -35,5 +35,26 @@ struct UnquantizeAndAddBiasAndWrite { UnquantizeAndAddBiasAndWrite(float unquant_mult, const float* bias_addr, float* output_addr) : unquant_mult(unquant_mult), bias_addr(bias_addr), output_addr(output_addr) {} }; +template <typename Type> +struct SSRUSigmoidF { + const Type* bias_addr; + const Type* sigmoid_lut; + float quant_mult_f; + float quant_mult_bf; + float sigmoid_lut_range; + Type* output_addr; + + SSRUSigmoidF(const Type* bias_addr, const Type* sigmoid_lut, float quant_mult_f, float quant_mult_bf, float sigmoid_lut_range, Type* output_addr) : bias_addr(bias_addr), sigmoid_lut(sigmoid_lut), quant_mult_f(quant_mult_f), quant_mult_bf(quant_mult_bf), sigmoid_lut_range(sigmoid_lut_range), output_addr(output_addr) {} +}; + +template <typename Type> +struct SSRUPrecomputedPartOfHighway { + const Type* sigmoid_f_addr; + float scale; + Type* output_addr; + + SSRUPrecomputedPartOfHighway(const Type* sigmoid_f_addr, float scale, Type* output_addr) : sigmoid_f_addr(sigmoid_f_addr), scale(scale), output_addr(output_addr) {} +}; + } } diff --git a/callbacks/implementations.inl b/callbacks/implementations.inl index f80b2ed..74fa656 100644 --- a/callbacks/implementations.inl +++ b/callbacks/implementations.inl @@ -110,16 +110,117 @@ public: CPU_ATTR CallbackImpl(const UnquantizeAndAddBiasAndWrite& config) : config(config) { unquant_mult = set1_ps<vf>(config.unquant_mult); } + CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) { auto result = kernels::unquantize(input, unquant_mult); result = kernels::add_bias(result, config.bias_addr, info.col_idx); kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx); } + private: UnquantizeAndAddBiasAndWrite config; vf unquant_mult; }; +/* + * SSRUSigmoidF + * + * output = sigmoid_lut(scale(input + bias), scale)) + */ +template <> class 
CallbackImpl<CPUType::CPU_NAME, SSRUSigmoidF<int8_t>> { +public: + CPU_ATTR CallbackImpl(const SSRUSigmoidF<int8_t>& config) : config(config), buffered_inputs_n(0), buffered_info(0, 0, 0, 0) { + scale = set1_ps<vf>(config.quant_mult_bf / config.quant_mult_f); + scale2 = set1_ps<vf>((127.0f / config.sigmoid_lut_range) / config.quant_mult_bf); + } + + // Workaround. If the buffer size is not aligned to 4xsizeof(vec) then there'll be a problem with tails. + CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) { + buffered_inputs[buffered_inputs_n++] = input; + if (buffered_inputs_n == 1) + buffered_info = info; + else if (buffered_inputs_n == 4) { + callback(buffered_inputs[0], buffered_inputs[1], buffered_inputs[2], buffered_inputs[3], buffered_info); + buffered_inputs_n = 0; + } + } + +private: + SSRUSigmoidF<int8_t> config; + vf scale; + vf scale2; + + int buffered_inputs_n; + vi buffered_inputs[4]; + OutputBufferInfo buffered_info; + + CPU_ATTR void callback(vi input1, vi input2, vi input3, vi input4, const OutputBufferInfo& info) { + auto result = kernels::downcast32to8( + kernels::rescale(input1, scale), + kernels::rescale(input2, scale), + kernels::rescale(input3, scale), + kernels::rescale(input4, scale)); + result = kernels::add_bias(result, config.bias_addr, info.col_idx); + + auto tmp = kernels::upcast8to32(result); + result = kernels::downcast32to8( + kernels::rescale(tmp.first, scale2), + kernels::rescale(tmp.second, scale2), + kernels::rescale(tmp.third, scale2), + kernels::rescale(tmp.fourth, scale2)); + + result = kernels::lookup_8b(result, config.sigmoid_lut); + kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx); + } +}; + +/* + * SSRUPrecomputedPartOfHighway + * + * output = (1 - sigmoid) * input + */ +template <> class CallbackImpl<CPUType::CPU_NAME, SSRUPrecomputedPartOfHighway<int8_t>> { +public: + CPU_ATTR CallbackImpl(const SSRUPrecomputedPartOfHighway<int8_t>& config) : config(config), 
buffered_inputs_n(0), buffered_info(0, 0, 0, 0) { + scale = set1_ps<vf>(config.scale); + } + + // Workaround. If the buffer size is not aligned to 4xsizeof(vec) then there'll be a problem with tails. + CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) { + buffered_inputs[buffered_inputs_n++] = input; + if (buffered_inputs_n == 1) + buffered_info = info; + else if (buffered_inputs_n == 4) { + callback(buffered_inputs[0], buffered_inputs[1], buffered_inputs[2], buffered_inputs[3], buffered_info); + buffered_inputs_n = 0; + } + } + +private: + SSRUPrecomputedPartOfHighway<int8_t> config; + vf scale; + + int buffered_inputs_n; + vi buffered_inputs[4]; + OutputBufferInfo buffered_info; + + CPU_ATTR void callback(vi input1, vi input2, vi input3, vi input4, const OutputBufferInfo& info) { + // TODO: Use 255 for better resolution (it needs u8 intrinsics) + static const auto vconst_int8_max = set1_epi8<vi>(127); + + const auto offset = info.row_idx * info.cols + info.col_idx; + const auto sigmoid = *reinterpret_cast<const vi*>(config.sigmoid_f_addr + offset); + + auto result = kernels::downcast32to8( + kernels::rescale(input1, scale), + kernels::rescale(input2, scale), + kernels::rescale(input3, scale), + kernels::rescale(input4, scale)); + result = kernels::multiply_sat<int8_t>(sub_epi8(vconst_int8_max, sigmoid), result, 7); + kernels::write(result, config.output_addr, offset); + } +}; + } } diff --git a/intrinsics.h b/intrinsics.h index 4204d05..7314c82 100644 --- a/intrinsics.h +++ b/intrinsics.h @@ -173,6 +173,9 @@ INTGEMM_SSE2 static inline void storeu_ps(float* mem_addr, __m128 a) { INTGEMM_SSE2 static inline void storeu_si(__m128i* mem_addr, __m128i a) { _mm_storeu_si128(mem_addr, a); } +INTGEMM_SSE2 static inline __m128i sub_epi8(__m128i a, __m128i b) { + return _mm_sub_epi8(a, b); +} INTGEMM_SSE2 static inline __m128d sub_pd(__m128d a, __m128d b) { return _mm_sub_pd(a, b); } @@ -345,6 +348,9 @@ INTGEMM_AVX2 static inline void storeu_ps(float* 
mem_addr, __m256 a) { INTGEMM_AVX2 static inline void storeu_si(__m256i* mem_addr, __m256i a) { _mm256_storeu_si256(mem_addr, a); } +INTGEMM_AVX2 static inline __m256i sub_epi8(__m256i a, __m256i b) { + return _mm256_sub_epi8(a, b); +} INTGEMM_AVX2 static inline __m256d sub_pd(__m256d a, __m256d b) { return _mm256_sub_pd(a, b); } @@ -519,6 +525,9 @@ INTGEMM_AVX512BW static inline void storeu_ps(void* mem_addr, __m512 a) { INTGEMM_AVX512BW static inline void storeu_si(void* mem_addr, __m512i a) { _mm512_storeu_si512(mem_addr, a); } +INTGEMM_AVX512BW static inline __m512i sub_epi8(__m512i a, __m512i b) { + return _mm512_sub_epi8(a, b); +} INTGEMM_AVX512BW static inline __m512d sub_pd(__m512d a, __m512d b) { return _mm512_sub_pd(a, b); } |