Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Mateusz Chudyk <mateuszchudyk@gmail.com> 2019-07-22 19:04:22 +0300
committer Mateusz Chudyk <mateuszchudyk@gmail.com> 2019-09-30 16:05:00 +0300
commit c6851d1a9c8cab163cd86e539bc1fa42a52bf823 (patch)
tree 4e1050b5b91ec9457481c5c7b34a10c931035059
parent 30c7e3ab2d11723977ee2402c28f75b92280650b (diff)
Add callbacks for SSRUInteger [marian-ssru]
-rw-r--r-- callbacks/configs.h 21
-rw-r--r-- callbacks/implementations.inl 101
-rw-r--r-- intrinsics.h 9
3 files changed, 131 insertions, 0 deletions
diff --git a/callbacks/configs.h b/callbacks/configs.h
index 8e2eacc..4dfee71 100644
--- a/callbacks/configs.h
+++ b/callbacks/configs.h
@@ -35,5 +35,26 @@ struct UnquantizeAndAddBiasAndWrite {
UnquantizeAndAddBiasAndWrite(float unquant_mult, const float* bias_addr, float* output_addr) : unquant_mult(unquant_mult), bias_addr(bias_addr), output_addr(output_addr) {}
};
+template <typename Type>
+struct SSRUSigmoidF { // Config for the SSRUSigmoidF callback: rescale, add bias, rescale again, then sigmoid table lookup.
+ const Type* bias_addr; // Bias values; the int8_t callback adds them indexed by output column.
+ const Type* sigmoid_lut; // Lookup table the int8_t callback hands to kernels::lookup_8b.
+ float quant_mult_f; // Divisor of the first rescale factor (quant_mult_bf / quant_mult_f).
+ float quant_mult_bf; // Numerator of the first rescale factor; also divides the second one.
+ float sigmoid_lut_range; // Second rescale factor is (127 / sigmoid_lut_range) / quant_mult_bf.
+ Type* output_addr; // Destination; the int8_t callback writes at row_idx * cols + col_idx.
+
+ SSRUSigmoidF(const Type* bias_addr, const Type* sigmoid_lut, float quant_mult_f, float quant_mult_bf, float sigmoid_lut_range, Type* output_addr) : bias_addr(bias_addr), sigmoid_lut(sigmoid_lut), quant_mult_f(quant_mult_f), quant_mult_bf(quant_mult_bf), sigmoid_lut_range(sigmoid_lut_range), output_addr(output_addr) {} // Stores each argument in the member of the same name.
+};
+
+template <typename Type>
+struct SSRUPrecomputedPartOfHighway { // Config for the callback that computes (1 - sigmoid) * input.
+ const Type* sigmoid_f_addr; // Sigmoid values; the int8_t callback reads them at the same offset it writes to.
+ float scale; // Factor the int8_t callback passes to kernels::rescale for every input vector.
+ Type* output_addr; // Destination; the int8_t callback writes at row_idx * cols + col_idx.
+
+ SSRUPrecomputedPartOfHighway(const Type* sigmoid_f_addr, float scale, Type* output_addr) : sigmoid_f_addr(sigmoid_f_addr), scale(scale), output_addr(output_addr) {} // Stores each argument in the member of the same name.
+};
+
}
}
diff --git a/callbacks/implementations.inl b/callbacks/implementations.inl
index f80b2ed..74fa656 100644
--- a/callbacks/implementations.inl
+++ b/callbacks/implementations.inl
@@ -110,16 +110,117 @@ public:
CPU_ATTR CallbackImpl(const UnquantizeAndAddBiasAndWrite& config) : config(config) {
unquant_mult = set1_ps<vf>(config.unquant_mult);
}
+
CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
auto result = kernels::unquantize(input, unquant_mult);
result = kernels::add_bias(result, config.bias_addr, info.col_idx);
kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
}
+
private:
UnquantizeAndAddBiasAndWrite config;
vf unquant_mult;
};
+/*
+ * SSRUSigmoidF
+ *
+ * output = sigmoid_lut(rescale(rescale(input, scale) + bias, scale2))
+ */
+template <> class CallbackImpl<CPUType::CPU_NAME, SSRUSigmoidF<int8_t>> {
+public:
+ CPU_ATTR CallbackImpl(const SSRUSigmoidF<int8_t>& config) : config(config), buffered_inputs_n(0), buffered_info(0, 0, 0, 0) {
+ scale = set1_ps<vf>(config.quant_mult_bf / config.quant_mult_f); // factor of the first rescale, applied to the raw inputs
+ scale2 = set1_ps<vf>((127.0f / config.sigmoid_lut_range) / config.quant_mult_bf); // factor of the second rescale, applied after the bias is added
+ }
+
+ // Workaround: inputs are buffered and only processed in groups of four. Nothing in this class flushes a partial group, so if the number of calls is not a multiple of 4 the trailing inputs are never passed to callback().
+ CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
+ buffered_inputs[buffered_inputs_n++] = input;
+ if (buffered_inputs_n == 1)
+ buffered_info = info; // keep the info of the first vector; it positions the whole group
+ else if (buffered_inputs_n == 4) {
+ callback(buffered_inputs[0], buffered_inputs[1], buffered_inputs[2], buffered_inputs[3], buffered_info);
+ buffered_inputs_n = 0; // start collecting the next group
+ }
+ }
+
+private:
+ SSRUSigmoidF<int8_t> config;
+ vf scale;
+ vf scale2;
+
+ int buffered_inputs_n; // number of vectors currently held in buffered_inputs
+ vi buffered_inputs[4];
+ OutputBufferInfo buffered_info; // info captured together with the first vector of the current group
+
+ CPU_ATTR void callback(vi input1, vi input2, vi input3, vi input4, const OutputBufferInfo& info) { // processes one full group of four buffered vectors
+ auto result = kernels::downcast32to8(
+ kernels::rescale(input1, scale),
+ kernels::rescale(input2, scale),
+ kernels::rescale(input3, scale),
+ kernels::rescale(input4, scale));
+ result = kernels::add_bias(result, config.bias_addr, info.col_idx);
+
+ auto tmp = kernels::upcast8to32(result); // split back into four parts (tmp.first .. tmp.fourth) for the second rescale
+ result = kernels::downcast32to8(
+ kernels::rescale(tmp.first, scale2),
+ kernels::rescale(tmp.second, scale2),
+ kernels::rescale(tmp.third, scale2),
+ kernels::rescale(tmp.fourth, scale2));
+
+ result = kernels::lookup_8b(result, config.sigmoid_lut);
+ kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
+ }
+};
+
+/*
+ * SSRUPrecomputedPartOfHighway
+ *
+ * output = (1 - sigmoid) * input, computed in 8-bit as multiply_sat(127 - sigmoid, rescale(input, scale), 7)
+ */
+template <> class CallbackImpl<CPUType::CPU_NAME, SSRUPrecomputedPartOfHighway<int8_t>> {
+public:
+ CPU_ATTR CallbackImpl(const SSRUPrecomputedPartOfHighway<int8_t>& config) : config(config), buffered_inputs_n(0), buffered_info(0, 0, 0, 0) {
+ scale = set1_ps<vf>(config.scale);
+ }
+
+ // Workaround: inputs are buffered and only processed in groups of four. Nothing in this class flushes a partial group, so if the number of calls is not a multiple of 4 the trailing inputs are never passed to callback().
+ CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
+ buffered_inputs[buffered_inputs_n++] = input;
+ if (buffered_inputs_n == 1)
+ buffered_info = info; // keep the info of the first vector; it positions the whole group
+ else if (buffered_inputs_n == 4) {
+ callback(buffered_inputs[0], buffered_inputs[1], buffered_inputs[2], buffered_inputs[3], buffered_info);
+ buffered_inputs_n = 0; // start collecting the next group
+ }
+ }
+
+private:
+ SSRUPrecomputedPartOfHighway<int8_t> config;
+ vf scale;
+
+ int buffered_inputs_n; // number of vectors currently held in buffered_inputs
+ vi buffered_inputs[4];
+ OutputBufferInfo buffered_info; // info captured together with the first vector of the current group
+
+ CPU_ATTR void callback(vi input1, vi input2, vi input3, vi input4, const OutputBufferInfo& info) { // processes one full group of four buffered vectors
+ // TODO: Use 255 for better resolution (it needs u8 intrinsics)
+ static const auto vconst_int8_max = set1_epi8<vi>(127);
+
+ const auto offset = info.row_idx * info.cols + info.col_idx; // shared by the sigmoid read and the output write
+ const auto sigmoid = *reinterpret_cast<const vi*>(config.sigmoid_f_addr + offset); // NOTE(review): plain dereference as vi; presumably requires this address to be suitably aligned - confirm
+
+ auto result = kernels::downcast32to8(
+ kernels::rescale(input1, scale),
+ kernels::rescale(input2, scale),
+ kernels::rescale(input3, scale),
+ kernels::rescale(input4, scale));
+ result = kernels::multiply_sat<int8_t>(sub_epi8(vconst_int8_max, sigmoid), result, 7); // NOTE(review): the 7 is presumably a right shift undoing the 127 scaling - confirm against kernels::multiply_sat
+ kernels::write(result, config.output_addr, offset);
+ }
+};
+
}
}
diff --git a/intrinsics.h b/intrinsics.h
index 4204d05..7314c82 100644
--- a/intrinsics.h
+++ b/intrinsics.h
@@ -173,6 +173,9 @@ INTGEMM_SSE2 static inline void storeu_ps(float* mem_addr, __m128 a) {
INTGEMM_SSE2 static inline void storeu_si(__m128i* mem_addr, __m128i a) { // __m128i overload; forwards to _mm_storeu_si128.
_mm_storeu_si128(mem_addr, a);
}
+INTGEMM_SSE2 static inline __m128i sub_epi8(__m128i a, __m128i b) { // __m128i overload; forwards to _mm_sub_epi8(a, b).
+ return _mm_sub_epi8(a, b);
+}
INTGEMM_SSE2 static inline __m128d sub_pd(__m128d a, __m128d b) { // __m128d overload; forwards to _mm_sub_pd(a, b).
return _mm_sub_pd(a, b);
}
@@ -345,6 +348,9 @@ INTGEMM_AVX2 static inline void storeu_ps(float* mem_addr, __m256 a) {
INTGEMM_AVX2 static inline void storeu_si(__m256i* mem_addr, __m256i a) { // __m256i overload; forwards to _mm256_storeu_si256.
_mm256_storeu_si256(mem_addr, a);
}
+INTGEMM_AVX2 static inline __m256i sub_epi8(__m256i a, __m256i b) { // __m256i overload; forwards to _mm256_sub_epi8(a, b).
+ return _mm256_sub_epi8(a, b);
+}
INTGEMM_AVX2 static inline __m256d sub_pd(__m256d a, __m256d b) { // __m256d overload; forwards to _mm256_sub_pd(a, b).
return _mm256_sub_pd(a, b);
}
@@ -519,6 +525,9 @@ INTGEMM_AVX512BW static inline void storeu_ps(void* mem_addr, __m512 a) {
INTGEMM_AVX512BW static inline void storeu_si(void* mem_addr, __m512i a) { // __m512i overload (destination taken as void*); forwards to _mm512_storeu_si512.
_mm512_storeu_si512(mem_addr, a);
}
+INTGEMM_AVX512BW static inline __m512i sub_epi8(__m512i a, __m512i b) { // __m512i overload; forwards to _mm512_sub_epi8(a, b).
+ return _mm512_sub_epi8(a, b);
+}
INTGEMM_AVX512BW static inline __m512d sub_pd(__m512d a, __m512d b) { // __m512d overload; forwards to _mm512_sub_pd(a, b).
return _mm512_sub_pd(a, b);
}