Welcome to mirror list, hosted at ThFree Co, Russian Federation.

ssse3_gemm.h - github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 584154126fd8c8b4862da6f47d79256fd3946120 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#pragma once

#include "interleave.h"
#include "kernels.h"
#include "multiply.h"
#include "types.h"

#include <cassert>
#include <cstdint>
#include <cstring>
#include <stdint.h>

// 16-bit is in sse2_gemm.h

namespace intgemm {

namespace ssse3 {

// Loads one __m128 worth of floats from `input` (via the project helper
// loadu_ps) and passes them, together with the broadcast multiplier register,
// to kernels::quantize. The exact rounding/scaling behaviour lives in
// kernels::quantize — presumably scale-then-round to int32 lanes; confirm there.
INTGEMM_SSSE3 inline __m128i QuantizerGrab(const float *input, const __m128 quant_mult_reg) {
  const auto floats = loadu_ps<__m128>(input);
  return kernels::quantize(floats, quant_mult_reg);
}

// Macro-instantiate the column-selection helpers for prepared B with SSSE3
// attributes and 128-bit registers. Presumably these expand to
// ssse3::SelectColumnsOfB and ssse3::SelectColumnsOfB_ColumnMajor, which the
// SSSE3_8bit wrappers in this file call — confirm against the macro definitions.
INTGEMM_SELECT_COL_B(INTGEMM_SSSE3, __m128i)
INTGEMM_SELECT_COL_B_COLUMN_MAJOR(INTGEMM_SSSE3, __m128i)

/* Quantizes 16 floats at a time into one __m128i of 16 signed 8-bit values.
 *
 * Each tile reads four groups of four floats (one __m128 each), quantizes each
 * group through QuantizerGrab using the multiplier passed to the constructor,
 * narrows the results with the saturating packs _mm_packs_epi32 then
 * _mm_packs_epi16, and finally replaces -128 with -127 so every output byte
 * lies in [-127, 127].
 */
class QuantizeTile8 {
  public:
    typedef __m128i Register;

    // mult is broadcast to all four float lanes and handed to QuantizerGrab.
    INTGEMM_SSSE3 explicit QuantizeTile8(float mult) : mult_reg_(_mm_set1_ps(mult)) {}

    // Quantizes 8 floats at `input` plus 8 floats at `input + 2 * cols`,
    // skipping one row in between (assuming `cols` floats per row).
    INTGEMM_SSSE3 inline __m128i ForReshape(const float *input, Index cols) const {
      return Tile(input, input + 4, input + 2 * cols, input + 2 * cols + 4);
    }

    // Quantizes 16 consecutive floats starting at `input`.
    INTGEMM_SSSE3 inline __m128i Consecutive(const float *input) const {
      return Tile(input, input + 4, input + 8, input + 12);
    }

    // Like Consecutive, but each byte is shifted by +127 (range [0, 254]) for
    // use as unsigned 8-bit data.
    INTGEMM_SSSE3 inline __m128i ConsecutiveU(const float *input) const {
      return TileU(input, input + 4, input + 8, input + 12);
    }

    // Gathers four runs of four floats starting at `input` and quantizes them as
    // one tile. `cols_left` counts the floats remaining in the current row; when
    // fewer than one register's worth remain, `input` is advanced by
    // cols * (row_step - 1) and `cols_left` grows by `cols` (presumably jumping
    // to the row `row_step` rows ahead — confirm against callers).
    INTGEMM_SSSE3 Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const {
      constexpr auto kFloatsPerRegister = sizeof(Register) / sizeof(float);
      const float* inputs[4];
      for (const float *&slot : inputs) {
        while (cols_left < kFloatsPerRegister) {
          input += cols * (row_step - 1);
          cols_left += cols;
        }
        slot = input;
        input += kFloatsPerRegister;
        cols_left -= kFloatsPerRegister;
      }
      return Tile(inputs[0], inputs[1], inputs[2], inputs[3]);
    }

    // Quantize 16xfloat (four groups of four) into 16xint8_t in [-127, 127].
    INTGEMM_SSSE3 inline __m128i Tile(const float *input0, const float *input1, const float *input2, const float *input3) const {
      const __m128i neg128 = _mm_set1_epi8(-128);
      __m128i g0 = QuantizerGrab(input0, mult_reg_);
      __m128i g1 = QuantizerGrab(input1, mult_reg_);
      __m128i g2 = QuantizerGrab(input2, mult_reg_);
      __m128i g3 = QuantizerGrab(input3, mult_reg_);
      __m128i packed0 = _mm_packs_epi32(g0, g1);
      __m128i packed1 = _mm_packs_epi32(g2, g3);
      // No permute needed: packs keeps lanes in order on SSE.
      __m128i packed = _mm_packs_epi16(packed0, packed1);
      /* Ban -128.
       * Don't use the SSE4.1 instruction _mm_max_epi8(packed, neg127).  Instead,
       * use SSE2 instructions _mm_cmpeq_epi8 and _mm_sub_epi8.
       * The first generates 0xff for fields -128.
       * The second subtracts 0xff (i.e. -1) from -128 which has the effect of
       * converting to -127.
       */
      __m128i evils = _mm_cmpeq_epi8(packed, neg128);
      return _mm_sub_epi8(packed, evils);
    }

  private:
    // Tile() shifted by +127 per byte: maps [-127, 127] onto [0, 254].
    INTGEMM_SSSE3 inline __m128i TileU(const float *input0, const float *input1, const float *input2, const float *input3) const {
      const __m128i pos127 = _mm_set1_epi8(127);
      return _mm_add_epi8(Tile(input0, input1, input2, input3), pos127);
    }

    const __m128 mult_reg_;
};

} // namespace

// 8-bit integer GEMM backend built on 128-bit registers.
// pmaddubsw (the 8-bit multiply) is INTGEMM_SSSE3, so pedantically that's the version we need.
struct SSSE3_8bit {
  typedef int8_t Integer;

  // Prepares A (rows * cols floats) as signed 8-bit values.
  // Currently A is prepared by quantization but this could theoretically change.
  INTGEMM_SSSE3 static inline void PrepareA(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
    Quantize(input, output, quant_mult, rows * cols);
  }

 private:
  INTGEMM_QUANTIZE_THREAD(INTGEMM_SSSE3, __m128i, ssse3)
 public:
  INTGEMM_QUANTIZE(INTGEMM_SSSE3, __m128i, ssse3)

  // Version with unsigned int + 127: prepares A as uint8 (quantized value + 127).
  // Currently A is prepared by quantization but this could theoretically change.
  INTGEMM_SSSE3 static inline void PrepareA(const float *input, uint8_t *output, float quant_mult, Index rows, Index cols) {
    QuantizeU(input, output, quant_mult, rows * cols);
  }

  // Quantizes `size` floats into uint8 values (signed quantization + 127).
  // Preconditions, checked only by assert: size is a multiple of 16, and both
  // input and output are 16-byte aligned (the loop stores whole aligned
  // __m128i registers and stops only when `input` hits `end` exactly).
  INTGEMM_SSSE3 static void QuantizeU(const float *input, uint8_t *output, float quant_mult, Index size) {
    assert(size % 16 == 0);
    assert(reinterpret_cast<uintptr_t>(input) % 16 == 0);
    assert(reinterpret_cast<uintptr_t>(output) % 16 == 0);
    ssse3::QuantizeTile8 q(quant_mult);
    const float *end = input + size;
    for (; input != end; input += 16, output += 16) {
      *reinterpret_cast<__m128i*>(output) = q.ConsecutiveU(input);
    }
  }

  // Tile size for B; B must be a multiple of this block size.
  static const Index kBTileRow = 16;
  static const Index kBTileCol = 8;

  // Macro-generated PrepareB variants (definitions live in the project macro headers).
  INTGEMM_PREPARE_B_8(INTGEMM_SSSE3, ssse3::QuantizeTile8)
  INTGEMM_PREPARE_B_QUANTIZED_TRANSPOSED(INTGEMM_SSSE3, CPUType::SSE2, int8_t)
  INTGEMM_PREPARE_B_TRANSPOSED(INTGEMM_SSSE3, ssse3::QuantizeTile8, int8_t)

  // Forwards to ssse3::SelectColumnsOfB, viewing the int8 buffers as arrays of
  // __m128i; presumably copies the columns listed in [cols_begin, cols_end).
  INTGEMM_SSSE3 static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
    ssse3::SelectColumnsOfB(reinterpret_cast<const __m128i*>(input), reinterpret_cast<__m128i*>(output), rows, cols_begin, cols_end);
  }
  // Column-major counterpart; forwards to ssse3::SelectColumnsOfB_ColumnMajor.
  INTGEMM_SSSE3 static void SelectColumnsB_ColumnsMajor(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
    ssse3::SelectColumnsOfB_ColumnMajor(reinterpret_cast<const __m128i*>(input), reinterpret_cast<__m128i*>(output), rows, cols_begin, cols_end);
  }

  // NOTE(review): the multiply/bias macros below are given CPUType::SSE2 while
  // kUses is SSSE3 — presumably that argument selects the callback
  // implementation; confirm in multiply.h before changing.
  INTGEMM_MULTIPLY8(__m128i, INTGEMM_SSSE3, CPUType::SSE2)

  INTGEMM_MULTIPLY8SHIFT(__m128i, INTGEMM_SSSE3, CPUType::SSE2)

  INTGEMM_PREPAREBIASFOR8(__m128i, INTGEMM_SSSE3, CPUType::SSE2)

  constexpr static const char *const kName = "8-bit SSSE3";

  static const CPUType kUses = CPUType::SSSE3;
};

} // namespace intgemm