Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- ssse3_gemm.cc 89
-rw-r--r-- ssse3_gemm.h 30
2 files changed, 119 insertions, 0 deletions
diff --git a/ssse3_gemm.cc b/ssse3_gemm.cc
new file mode 100644
index 0000000..7a6cd24
--- /dev/null
+++ b/ssse3_gemm.cc
@@ -0,0 +1,89 @@
+#include "ssse3_gemm.h"
+
+#include "interleave.h"
+#include "multiply.h"
+
+#include <stdint.h>
+#include <cassert>
+#include <xmmintrin.h>
+#include <emmintrin.h>
+
+namespace intgemm {
+
+#ifdef __SSSE3__
+
+namespace {
+// Same implementation as the AVX512 version, differing only in register width: scales 4 floats by quant_mult_reg and converts them to 4 32-bit integers.
+inline __m128i QuantizerGrab(const float *input, const __m128 quant_mult_reg) {
+ return _mm_cvtps_epi32(_mm_mul_ps(*reinterpret_cast<const __m128*>(input), quant_mult_reg));  // input is dereferenced as a __m128, so it is expected to be 16-byte aligned (Quantize asserts this).
+}
+
+class QuantizeTile8 {  // Turns 16 floats into 16 int8_t values packed in one __m128i, with -128 replaced by -127.
+ public:
+ typedef __m128i Integer;  // Register type returned by the tile functions below.
+
+ explicit QuantizeTile8(float mult) : mult_reg_(_mm_set1_ps(mult)) {}  // Broadcasts mult into every float lane of mult_reg_.
+
+ inline __m128i ForReshape(const float *input, int cols) {
+ // Skip a row: reads 8 floats at input and 8 at input + 2 * cols (assumes cols is the row length; presumably the layout PrepareBFor8 wants -- confirm there).
+ return Tile(input, input + 2 * cols);
+ }
+
+ inline __m128i Consecutive(const float *input) {  // Quantizes 16 contiguous floats: input[0..7] followed by input[8..15].
+ return Tile(input, input + 8);
+ }
+
+ private:
+ // Quantize 16xfloat (8 at input0, then 8 at input1) into 16xint8_t.
+ inline __m128i Tile(const float *input0, const float *input1) {
+ const __m128i neg128 = _mm_set1_epi8(-128);  // Comparison pattern for the banned value.
+ __m128i g0 = QuantizerGrab(input0, mult_reg_);  // g0..g3 each hold four scaled 32-bit integers.
+ __m128i g1 = QuantizerGrab(input0 + 4, mult_reg_);
+ __m128i g2 = QuantizerGrab(input1, mult_reg_);
+ __m128i g3 = QuantizerGrab(input1 + 4, mult_reg_);
+ __m128i packed0 = _mm_packs_epi32(g0, g1);  // Narrow 32-bit -> 16-bit with signed saturation.
+ __m128i packed1 = _mm_packs_epi32(g2, g3);
+ __m128i packed = _mm_packs_epi16(packed0, packed1);  // Narrow 16-bit -> 8-bit with signed saturation.
+ /* Ban -128.
+ * Don't use the SSE4.1 instruction _mm_max_epi8(packed, neg127). Instead,
+ * use the SSE2 instructions _mm_cmpeq_epi8 and _mm_sub_epi8.
+ * The first yields 0xff (i.e. -1) in every byte equal to -128, else 0.
+ * The second subtracts that mask: -128 - (-1) = -127, while every other
+ * byte has 0 subtracted and is left unchanged.
+ */
+ // SSE4.1 equivalent, deliberately not used: packed = _mm_max_epi8(packed, neg127);
+ __m128i evils = _mm_cmpeq_epi8(packed, neg128);
+ return _mm_sub_epi8(packed, evils);
+ // No permute needed: the packs instructions keep elements in order for SSE.
+ }
+
+ private:
+ const __m128 mult_reg_;  // Quantization multiplier replicated across all four float lanes.
+};
+
+} // namespace
+
+void SSSE3_8bit::Quantize(const float *input, int8_t *output, float quant_mult, int size) {  // Writes size int8_t values to output, one per input float scaled by quant_mult.
+ assert(size % 16 == 0);  // The loop advances exactly 16 elements per step and stops only when input == end.
+ assert(reinterpret_cast<uintptr_t>(input) % 16 == 0);  // Input is read through __m128 pointer dereferences.
+ assert(reinterpret_cast<uintptr_t>(output) % 16 == 0);  // Output is written through an __m128i pointer dereference.
+ QuantizeTile8 q(quant_mult);
+ const float *end = input + size;
+ for (; input != end; input += 16, output += 16) {
+ *reinterpret_cast<__m128i*>(output) = q.Consecutive(input);  // 16 floats in, 16 int8_t out.
+ }
+}
+
+void SSSE3_8bit::PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols) {  // Forwards to PrepareBFor8 (defined outside this file) with a QuantizeTile8 built from quant_mult.
+ PrepareBFor8(input, output, QuantizeTile8(quant_mult), rows, cols);
+}
+
+void SSSE3_8bit::Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) {  // Forwards to Multiply8_SSE2OrAVX2 (defined outside this file), instantiated for 128-bit registers.
+ Multiply8_SSE2OrAVX2<__m128i, __m128>(A, B, C, unquant_mult, A_rows, width, B_cols);  // NOTE(review): presumably C = unquant_mult * (A * B) with B in PrepareB layout -- confirm in multiply.h.
+}
+
+const char *const SSSE3_8bit::kName = "8-bit SSSE3";
+
+#endif // __SSSE3__
+
+} // namespace intgemm
diff --git a/ssse3_gemm.h b/ssse3_gemm.h
new file mode 100644
index 0000000..afbe4f0
--- /dev/null
+++ b/ssse3_gemm.h
@@ -0,0 +1,30 @@
+#pragma once
+#include <stdint.h>
+
+// 16-bit is in sse2_gemm.h
+
+namespace intgemm {
+
+// 8-bit backend. pmaddubsw (the 8-bit multiply) is SSSE3, so pedantically that's the version we need.
+struct SSSE3_8bit {
+ typedef int8_t Integer;  // Element type of the quantized buffers taken and produced below.
+
+ // Currently A is prepared by plain quantization of all rows * cols values, but this could theoretically change.
+ static inline void PrepareA(const float *input, int8_t *output, float quant_mult, int rows, int cols) {
+ Quantize(input, output, quant_mult, rows * cols);
+ }
+
+ static void Quantize(const float *input, int8_t *output, float quant_mult, int size);  // ssse3_gemm.cc asserts size % 16 == 0 and 16-byte alignment of input and output.
+
+ // Tile size for B; B must be a multiple of this block size.
+ static const int kBTileRow = 16;
+ static const int kBTileCol = 8;
+
+ static void PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols);  // Quantizes B via PrepareBFor8; see ssse3_gemm.cc.
+
+ static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);  // Implemented with Multiply8_SSE2OrAVX2; see ssse3_gemm.cc.
+
+ static const char *const kName;  // Defined in ssse3_gemm.cc as "8-bit SSSE3".
+};
+
+} // namespace intgemm