Welcome to mirror list, hosted at ThFree Co, Russian Federation.

multiply.h - github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: e27f31faaa3b736418e5cb5a26a351166a42fde3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
#pragma once

#include "intrinsics.h"

namespace intgemm {

// Horizontal maximum of the four floats in an SSE register.
// Two shuffle+max rounds halve the live width each time so that lane 0 ends
// up holding the maximum of all four lanes.
SSE2 static inline float MaxFloat32(__m128 a) {
  // Fold to just using the first 64 bits: lanes 0,1 become max(a0,a2), max(a1,a3).
  __m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2);
  a = _mm_max_ps(a, second_half);
  // Fold to just using the first 32 bits: lane 0 becomes the overall maximum.
  second_half = _mm_shuffle_ps(a, a, 1);
  a = _mm_max_ps(a, second_half);
  // Read lane 0 with the dedicated intrinsic rather than type-punning the
  // vector through a float pointer; it still compiles to nothing.
  return _mm_cvtss_f32(a);
}

// With 128-bit registers Pack0123 has already produced fully reduced 32-bit
// sums, so there is nothing left to permute: just bundle both registers.
SSE2 static inline MultiplyResult128 PermuteSummer(__m128i pack0123, __m128i pack4567) {
  MultiplyResult128 result;
  result.pack4567 = pack4567;
  result.pack0123 = pack0123;
  return result;
}

// Write one row's 8 outputs: convert the int32 sums to float, scale each by
// unquant_reg, and store pack0123 to to[0..3] and pack4567 to to[4..7].
// Both stores are aligned 128-bit stores, so `to` must be 16-byte aligned.
SSE2 static inline void WriteC(float *to, MultiplyResult128 total, __m128 unquant_reg) {
  __m128 first_four = mul_ps(cvtepi32_ps(total.pack0123), unquant_reg);
  __m128 last_four = mul_ps(cvtepi32_ps(total.pack4567), unquant_reg);
  _mm_store_ps(to, first_four);
  _mm_store_ps(to + 4, last_four);
}

// Horizontal maximum of eight floats: take the lane-wise max of the two
// 128-bit halves, then hand the result to the SSE reduction.
AVX2 static inline float MaxFloat32(__m256 a) {
  __m128 low_half = _mm256_castps256_ps128(a);
  __m128 high_half = _mm256_extractf128_ps(a, 1);
  return MaxFloat32(max_ps(low_half, high_half));
}

// Finish the reduction for 256-bit registers.  Each input carries partial
// sums in both 128-bit lanes; the result's low lane is pack0123's
// (low + high) and its high lane is pack4567's (low + high).
AVX2 static inline __m256i PermuteSummer(__m256i pack0123, __m256i pack4567) {
  // Low lane <- high lane of pack0123, high lane <- low lane of pack4567.
  const __m256i crossed = _mm256_permute2f128_si256(pack0123, pack4567, 0x21);
  // Low lane <- low lane of pack0123, high lane <- high lane of pack4567.
  const __m256i straight = _mm256_blend_epi32(pack0123, pack4567, 0xf0);
  return _mm256_add_epi32(straight, crossed);
}

// Convert the eight int32 sums to float, scale them by unquant_reg and store
// them at `to` with an aligned 256-bit store (`to` must be 32-byte aligned).
AVX2 static inline void WriteC(float *to, __m256i total, __m256 unquant_reg) {
  __m256 scaled = mul_ps(cvtepi32_ps(total), unquant_reg);
  _mm256_store_ps(to, scaled);
}

#ifndef INTGEMM_NO_AVX512

// Finish the reduction for 512-bit registers.  Each input carries partial
// sums in four 128-bit lanes; the returned 256-bit register holds the total
// for pack0123 in its low 128 bits and the total for pack4567 in its high 128 bits.
AVX512F static inline __m256i PermuteSummer(__m512i pack0123, __m512i pack4567) {
  // 128-bit lanes: [pack0123 lane 0, pack4567 lane 0, pack0123 lane 2, pack4567 lane 2]
  const __m512i even_lanes = _mm512_mask_permutex_epi64(pack0123, 0xcc, pack4567, (0 << 4) | (1 << 6));
  // 128-bit lanes: [pack0123 lane 1, pack4567 lane 1, pack0123 lane 3, pack4567 lane 3]
  const __m512i odd_lanes = _mm512_mask_permutex_epi64(pack4567, 0x33, pack0123, 2 | (3 << 2));
  // The 32-bit pattern is now 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7,
  // so folding the two 256-bit halves together completes the sums.
  const __m512i pairwise = _mm512_add_epi32(even_lanes, odd_lanes);
  const __m256i lower = _mm512_castsi512_si256(pairwise);
  const __m256i upper = _mm512_extracti64x4_epi64(pairwise, 1);
  return _mm256_add_epi32(lower, upper);
}

// Find the maximum float: lane-wise max of the two 256-bit halves, then the
// AVX reduction.  _mm512_extractf32x8_ps is why this needs AVX512DQ.
AVX512DQ static inline float MaxFloat32(__m512 a) {
  __m256 low_half = _mm512_castps512_ps256(a);
  __m256 high_half = _mm512_extractf32x8_ps(a, 1);
  return MaxFloat32(max_ps(low_half, high_half));
}

#endif

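/* Take 4 registers with 32-bit values to be horizontally added.  Reduce them
 * to one register with 32-bit values in the pattern 1 2 3 4 1 2 3 4, leaving
 * the final addition (which crosses 128-bit lanes) to the caller.
 * The template below is the readable form; PACK_DEFINE stamps out the same
 * code once per register type with the matching target attribute.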
template <class Register> inline Register Pack0123(Register sum0, Register sum1, Register sum2, Register sum3) {
  // 1 2 1 2 1 2 1 2
  Interleave32(sum0, sum1);
  Register pack01 = add_epi32(sum0, sum1);
  // 3 4 3 4 3 4 3 4
  Interleave32(sum2, sum3);
  Register pack23 = add_epi32(sum2, sum3);
  Interleave64(pack01, pack23);
  // 1 2 3 4 1 2 3 4
  return add_epi32(pack01, pack23);
}
 */
// PACK_DEFINE(target_attr, Reg) defines Pack0123 for register type Reg under
// the given target attribute.  It reduces four registers of 32-bit sums so
// that, within each 128-bit lane, the result holds the pattern 1 2 3 4.
#define PACK_DEFINE(target_attr, Reg) \
target_attr inline Reg Pack0123(Reg sum0, Reg sum1, Reg sum2, Reg sum3) { \
  /* Pair columns 1 and 2: 1 2 1 2 ... */ \
  Interleave32(sum0, sum1); \
  Reg first_pair = add_epi32(sum0, sum1); \
  /* Pair columns 3 and 4: 3 4 3 4 ... */ \
  Interleave32(sum2, sum3); \
  Reg second_pair = add_epi32(sum2, sum3); \
  /* Merge the pairs: 1 2 3 4 ... */ \
  Interleave64(first_pair, second_pair); \
  return add_epi32(first_pair, second_pair); \
}

// Instantiate Pack0123 for 128-bit (SSE2) and 256-bit (AVX2) registers, and
// for 512-bit (AVX512F) registers unless INTGEMM_NO_AVX512 is defined.
PACK_DEFINE(SSE2, __m128i)
PACK_DEFINE(AVX2, __m256i)
#ifndef INTGEMM_NO_AVX512
PACK_DEFINE(AVX512F, __m512i)
#endif

// 16-bit multiplier for SSE2, AVX2, and AVX512.
// C = A * B * unquant_mult
//
// This has been substantially revised from Jacob Devlin's SSE code which is:
// Copyright (c) 2017 Microsoft Corporation

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

// A is a row-major quantized matrix (from PrepareA)
// B is a rearranged quantized matrix (from PrepareB)
// C is output in row-major form.
//
// All of A, B, and C must be aligned to a multiple of the register size:
// SSE2: 16 bytes
// AVX2: 32 bytes
// AVX512: 64 bytes.
//
// A_rows can be anything non-negative.
// width must be a multiple of the register size.
// B_cols must be a multiple of 8.
// Multiply16
// MULTIPLY16_define(Integer, target) defines a templated 16-bit Multiply for
// register type Integer under the target attribute `target`.  For every
// block of 8 columns of B and every row of A it calls
// functor(A_rowidx, B_cols, B0_colidx, total), where total is whatever
// PermuteSummer returns for Integer.
#define MULTIPLY16_define(Integer, target) \
template <class WriteC> target static void Multiply(const int16_t *A, const int16_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { \
  assert(width % (sizeof(Integer) / sizeof(int16_t)) == 0); \
  assert(B_cols % 8 == 0); \
  assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \
  assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \
  /* The alignment check on C and the unquantization multiply were moved to the functor (WriteC). */ \
  const int simd_width = width / (sizeof(Integer) / sizeof(int16_t)); \
  const Integer *B0_col = reinterpret_cast<const Integer *>(B); \
  /* Go over 8 columns of B at a time. */ \
  for (int B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
    /* Process one row of A at a time.  Doesn't seem to be faster to do multiple rows of A at once.*/ \
    for (int A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
      const Integer *A_row = reinterpret_cast<const Integer*>(A + A_rowidx * width); \
      /* sumN holds packed 32-bit partial sums for column B0_colidx + N of B times this row of A. \
         Rather than initializing as zeros and adding, initialize from the first register of the row. */ \
      Integer a = A_row[0]; \
      Integer sum0 = madd_epi16(a, B0_col[0]); \
      Integer sum1 = madd_epi16(a, B0_col[1]); \
      Integer sum2 = madd_epi16(a, B0_col[2]); \
      Integer sum3 = madd_epi16(a, B0_col[3]); \
      Integer sum4 = madd_epi16(a, B0_col[4]); \
      Integer sum5 = madd_epi16(a, B0_col[5]); \
      Integer sum6 = madd_epi16(a, B0_col[6]); \
      Integer sum7 = madd_epi16(a, B0_col[7]); \
      /* Iterate over the rest of the shared (inner) dimension. */ \
      for (int k = 1; k < simd_width; ++k) { \
        Integer a_k = A_row[k]; \
        const Integer *B_k = B0_col + k * 8; \
        /* Multiply 16-bit, horizontally add to packed 32-bit integers.*/ \
        Integer mult0 = madd_epi16(a_k, B_k[0]); \
        Integer mult1 = madd_epi16(a_k, B_k[1]); \
        Integer mult2 = madd_epi16(a_k, B_k[2]); \
        Integer mult3 = madd_epi16(a_k, B_k[3]); \
        Integer mult4 = madd_epi16(a_k, B_k[4]); \
        Integer mult5 = madd_epi16(a_k, B_k[5]); \
        Integer mult6 = madd_epi16(a_k, B_k[6]); \
        Integer mult7 = madd_epi16(a_k, B_k[7]); \
        /* Sum packed 32-bit integers with danger of overflow.  TODO: accumulate in 64-bit every so often.*/ \
        sum0 = add_epi32(sum0, mult0); \
        sum1 = add_epi32(sum1, mult1); \
        sum2 = add_epi32(sum2, mult2); \
        sum3 = add_epi32(sum3, mult3); \
        sum4 = add_epi32(sum4, mult4); \
        sum5 = add_epi32(sum5, mult5); \
        sum6 = add_epi32(sum6, mult6); \
        sum7 = add_epi32(sum7, mult7); \
      } \
      /* Reduce sums within 128-bit lanes.*/ \
      Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \
      Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \
      /*The specific implementation may need to reduce further.*/ \
      auto total = PermuteSummer(pack0123, pack4567); \
      functor(A_rowidx, B_cols, B0_colidx, total); \
    } \
  } \
}

/* 8-bit matrix multiply used by AVX and AVX2.
 * These have three peculiar properties:
 * 1. The sign instructions don't exist in AVX512.
 * 2. With only 16 registers, gcc's register allocation failed, so I wrote the
 *    AVX2 version in my own asm.
 * 3. They support 3-argument vpsignb and vpmaddubsw.
 *
 * Fun fact: AVX introduced the three-argument vpsignb and vpmaddubsw but only
 * for 128-bit, despite the primary change in AVX being the addition of
 * 256-bit.  We had to wait for AVX2 to get 256-bit versions of vpsignb and
 * vpmaddubsw.  That's why this code is generic over 128-bit or 256-bit.
 */

// One inner-dimension step of the 8-bit AVX2 multiply.  For one register `a`
// of A and the 8 consecutive registers of B starting at `b`, it applies the
// sign of a to each column of b (vpsignb), multiplies by |a| (vpmaddubsw),
// and adds the packed 16-bit results into sum0..sum7 with saturation (vpaddsw).
AVX2 inline static void InnerAVX2(
    __m256i a, const __m256i *b,
    __m256i &sum0, __m256i &sum1, __m256i &sum2, __m256i &sum3,
    __m256i &sum4, __m256i &sum5, __m256i &sum6, __m256i &sum7) {
  // Annoyingly the only 8-bit multiply is signed * unsigned (maddubs).
  // So we take the sign bits off of a and apply them each b in a * b.
  //
  // We have only 16 YMM registers but we want to store:
  // 1 for a (or |a|)
  // 8 temporaries for applying sign to each column of B.
  // 8 sums.
  //
  // gcc's register allocator does:
  // 1 for a, do all the sign application, then overwrite with |a|
  // 8 temporaries
  // 7 sums in registers + 1 on the stack
  //
  // But it's possible to complete an operation early, freeing up its
  // temporary register for reuse.  But completing an operation early
  // requires us to have |a| for vpmaddubsw while completing the later
  // operation needs a again to apply sign.
  //
  // So we do two columns, 0 and 1, early.  This allows b0_b6 and b1_b7
  // to be reused by columns 6 and 7, respectively.  And there's enough
  // registers to store both a and |a|.
  //
  // These are the temporary variables used to process each column of b.
  // We let the compiler choose which register number is which, but force
  // it to allocate all registers.
  __m256i absa;
  __m256i b0_b6, b1_b7, b2, b3, b4, b5;
  // The 8 registers of b read by the asm, viewed as one object.  Passing the
  // address in a register ("r") alone tells the compiler nothing about the
  // memory the asm reads, so the object is also passed as an unused "m"
  // input below to declare that read.
  typedef struct { __m256i x[8]; } B_range;
  asm(
      // Copy the first 6 columns of b to registers.  We assume B has
      // been rearranged so that these 8 columns are consecutive.
      // vpsignb does not take a memory address as its second argument,
      // so this can't be inlined into vpsignb.
      "vmovdqa          (%[B]), %[b0_b6]\n"
      "vmovdqa   %c[size](%[B]), %[b1_b7]\n"
      // These multiplies are executed by the assembler, not by the CPU
      // at run time.
      // I would have liked to just initialize b2 etc above but that
      // would make it an input argument "+x" instead of "=&x".  And +x
      // counts as two operands for purposes of gcc's annoying 30-operand
      // limit.
      "vmovdqa 2*%c[size](%[B]), %[b2]\n"
      "vmovdqa 3*%c[size](%[B]), %[b3]\n"
      "vmovdqa 4*%c[size](%[B]), %[b4]\n"
      "vmovdqa 5*%c[size](%[B]), %[b5]\n"
      // Store the absolute value of a in absa.
      "vpabsb  %[a], %[absa]\n"
      // If a byte of a is negative, negate the corresponding byte in
      // b0_b6 etc.
      "vpsignb %[a], %[b0_b6], %[b0_b6]\n"
      "vpsignb %[a], %[b1_b7], %[b1_b7]\n"
      // Multiply signed * unsigned then horizontally add to form packed
      // 16-bit integers:
      // b0[0] * |a|[0] + b0[1] * |a|[1], b0[2] * |a|[2] + b0[3] * |a|[3], ...
      "vpmaddubsw %[b0_b6], %[absa], %[b0_b6]\n"
      "vpmaddubsw %[b1_b7], %[absa], %[b1_b7]\n"
      // vpmaddubsw has latency 5 so work on some other sign bits while
      // we're at it.
      "vpsignb %[a], %[b2], %[b2]\n"
      "vpsignb %[a], %[b3], %[b3]\n"
      "vpsignb %[a], %[b4], %[b4]\n"
      "vpsignb %[a], %[b5], %[b5]\n"
      // Perform a 16-bit add with saturation to accumulate sums.
      "vpaddsw %[b0_b6], %[sum0], %[sum0]\n"
      // Now we can reuse b0_b6 for b6
      "vmovdqa 6*%c[size](%[B]), %[b0_b6]\n"
      "vpaddsw %[b1_b7], %[sum1], %[sum1]\n"
      // Now we can reuse b1_b7 for b7
      "vmovdqa 7*%c[size](%[B]), %[b1_b7]\n"
      // More crunching while the load happens.
      "vpmaddubsw %[b2], %[absa], %[b2]\n"
      "vpmaddubsw %[b3], %[absa], %[b3]\n"
      "vpmaddubsw %[b4], %[absa], %[b4]\n"
      "vpsignb %[a], %[b0_b6], %[b0_b6]\n"
      "vpsignb %[a], %[b1_b7], %[b1_b7]\n"
      "vpmaddubsw %[b5], %[absa], %[b5]\n"
      "vpmaddubsw %[b0_b6], %[absa], %[b0_b6]\n"
      "vpmaddubsw %[b1_b7], %[absa], %[b1_b7]\n"
      "vpaddsw %[b2], %[sum2], %[sum2]\n"
      "vpaddsw %[b3], %[sum3], %[sum3]\n"
      "vpaddsw %[b4], %[sum4], %[sum4]\n"
      "vpaddsw %[b5], %[sum5], %[sum5]\n"
      "vpaddsw %[b0_b6], %[sum6], %[sum6]\n"
      "vpaddsw %[b1_b7], %[sum7], %[sum7]\n"
      : [sum0] "+x" (sum0),
        [sum1] "+x" (sum1),
        [sum2] "+x" (sum2),
        [sum3] "+x" (sum3),
        [sum4] "+x" (sum4),
        [sum5] "+x" (sum5),
        [sum6] "+x" (sum6),
        [sum7] "+x" (sum7),
        [b0_b6] "=&x" (b0_b6),
        [b1_b7] "=&x" (b1_b7),
        [b2] "=&x" (b2),
        [b3] "=&x" (b3),
        [b4] "=&x" (b4),
        [b5] "=&x" (b5),
        [absa] "=&x" (absa)
      : 
        // I would like to use m here but that non-deterministically
        // chooses %(eax) or -256$(eax) and there's no way to add to that
        // memory address:
        // https://gcc.gnu.org/ml/gcc-help/2011-04/msg00518.html
        //
        [B] "r" (reinterpret_cast<const B_range*>(b)),
        // Not referenced in the template: declares that the asm reads all
        // 8 registers of b, so the compiler keeps memory accesses to them
        // ordered with respect to this asm.  Brings the operand count to
        // 27 of gcc's 30.
        [B_mem] "m" (*reinterpret_cast<const B_range*>(b)),
        [a] "x" (a),
        [size] "i" (sizeof(__m256i))
    );
}


// For SSSE3 without AVX: one inner-dimension step of the 8-bit multiply.
// For one register `a` of A and the 8 consecutive registers of B at `b`, add
// (16-bit, saturating) into each sumN the maddubs product of |a| with b[N]
// after applying the sign of a to b[N].
SSSE3 inline static void InnerSSSE3(
    __m128i a, const __m128i *b,
    __m128i &sum0, __m128i &sum1, __m128i &sum2, __m128i &sum3,
    __m128i &sum4, __m128i &sum5, __m128i &sum6, __m128i &sum7) {
  // maddubs is unsigned * signed, so |a| is the unsigned operand and the
  // sign of a is moved onto each column of b instead.
  __m128i abs_a = abs_epi8(a);
  __m128i prod0 = maddubs_epi16(abs_a, sign_epi8(*(b + 0), a));
  __m128i prod1 = maddubs_epi16(abs_a, sign_epi8(*(b + 1), a));
  __m128i prod2 = maddubs_epi16(abs_a, sign_epi8(*(b + 2), a));
  __m128i prod3 = maddubs_epi16(abs_a, sign_epi8(*(b + 3), a));
  __m128i prod4 = maddubs_epi16(abs_a, sign_epi8(*(b + 4), a));
  __m128i prod5 = maddubs_epi16(abs_a, sign_epi8(*(b + 5), a));
  __m128i prod6 = maddubs_epi16(abs_a, sign_epi8(*(b + 6), a));
  __m128i prod7 = maddubs_epi16(abs_a, sign_epi8(*(b + 7), a));
  sum0 = adds_epi16(sum0, prod0);
  sum1 = adds_epi16(sum1, prod1);
  sum2 = adds_epi16(sum2, prod2);
  sum3 = adds_epi16(sum3, prod3);
  sum4 = adds_epi16(sum4, prod4);
  sum5 = adds_epi16(sum5, prod5);
  sum6 = adds_epi16(sum6, prod6);
  sum7 = adds_epi16(sum7, prod7);
}
//AVX2 or SSSE3 multiply
// MULTIPLY8_define(Integer, target) defines a templated 8-bit Multiply for
// register type Integer under the target attribute `target` (AVX2 or SSSE3);
// the inner loop calls Inner##target, i.e. InnerAVX2 or InnerSSSE3.  For every
// block of 8 columns of B and every row of A it calls
// functor(A_rowidx, B_cols, B0_colidx, total).
#define MULTIPLY8_define(Integer, target) \
template <class WriteC> target static void Multiply(const int8_t *A, const int8_t *B, WriteC functor, Index A_rows, Index width, Index B_cols) { \
  assert(width % sizeof(Integer) == 0); \
  assert(B_cols % 8 == 0); \
  assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); \
  assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \
  const int simd_width = width / sizeof(Integer); \
  const Integer *B0_col = reinterpret_cast<const Integer*>(B); \
  /*Go over 8 columns of B at a time.  Compare with < (as Multiply16 does) rather than != so a \
    B_cols that is not a multiple of 8 terminates instead of looping forever when asserts are off.*/ \
  for (int B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
    /*Process one row of A at a time.  Doesn't seem to be faster to do multiple rows of A at once.*/ \
    for (int A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
      /*Iterate over shared (inner) dimension.*/ \
      const Integer *A_live = reinterpret_cast<const Integer *>(A + A_rowidx * width); \
      const Integer *A_end = A_live + simd_width; \
      const Integer *B_live = B0_col; \
      /* Rather than initializing as zeros and adding, just initialize the first.*/ \
      Integer a = *(A_live++); \
      Integer a_positive = abs_epi8(a); \
      /* These will be packed 16-bit integers containing sums for each column of B multiplied by the row of A.*/ \
      Integer sum0 = maddubs_epi16(a_positive, sign_epi8(B_live[0], a)); \
      Integer sum1 = maddubs_epi16(a_positive, sign_epi8(B_live[1], a)); \
      Integer sum2 = maddubs_epi16(a_positive, sign_epi8(B_live[2], a)); \
      Integer sum3 = maddubs_epi16(a_positive, sign_epi8(B_live[3], a)); \
      Integer sum4 = maddubs_epi16(a_positive, sign_epi8(B_live[4], a)); \
      Integer sum5 = maddubs_epi16(a_positive, sign_epi8(B_live[5], a)); \
      Integer sum6 = maddubs_epi16(a_positive, sign_epi8(B_live[6], a)); \
      Integer sum7 = maddubs_epi16(a_positive, sign_epi8(B_live[7], a)); \
      B_live += 8; \
      /* Use A as the loop variable so the add can be done where gcc likes it for branch prediction.*/ \
      for (; A_live != A_end; ++A_live, B_live += 8) { \
        Inner##target(*A_live, B_live, sum0, sum1, sum2, sum3, sum4, sum5, sum6, sum7); \
      } \
      /* Convert 16-bit to 32-bit and add, not caring what parts are added. \
       * Implementations: \
       * 1. https://github.com/tesseract-ocr/tesseract/blob/master/src/arch/intsimdmatrixavx2.cpp#L67 under Apache license: \
       *   This does a multiply by 1 and horizontal add: \
       *    _mm512_madd_epi16(sum, _mm512_set1_epi16(1)) \
       *   Current fastest. \
       * \
       * 2. Signed extension and fold halves: \
       *    sum = _mm512_add_epi32( \
       *      _mm512_cvtepi16_epi32(_mm512_castsi512_si256(sum)), \
       *      _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(sum, 1))); \
       * \
       * 3. Sign extend by abuse of bitshift, then add. \
       * sum = _mm512_add_epi32( \
       *      _mm512_srai_epi32(_mm512_slli_epi32(sum, 16), 16), \
       *      _mm512_srai_epi32(sum, 16)); \
       */ \
      Integer ones = set1_epi16<Integer>(1); \
      sum0 = madd_epi16(sum0, ones); \
      sum1 = madd_epi16(sum1, ones); \
      sum2 = madd_epi16(sum2, ones); \
      sum3 = madd_epi16(sum3, ones); \
      sum4 = madd_epi16(sum4, ones); \
      sum5 = madd_epi16(sum5, ones); \
      sum6 = madd_epi16(sum6, ones); \
      sum7 = madd_epi16(sum7, ones); \
      Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); \
      Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); \
      auto total = PermuteSummer(pack0123, pack4567); \
      /* Hand the reduced results for columns B0_colidx..B0_colidx+7 of this row to the functor. */ \
      functor(A_rowidx, B_cols, B0_colidx, total); \
    } \
  } \
}


// Find the maximum absolute value of packed float32s.
/*
template <class Register> inline static float MaxAbsoluteBackend(const float *begin_float, const float *end_float) {
  assert(end_float > begin_float);
  assert((end_float - begin_float) % (sizeof(Register) / sizeof(float)) == 0);
  const Register *begin = reinterpret_cast<const Register*>(begin_float);
  const Register *end = reinterpret_cast<const Register*>(end_float);
  // Get the sign bit.
  union {float f; int32_t i;} float_convert;
  float_convert.i = 0x7fffffff;
  Register and_me = set1_ps<Register>(float_convert.f);
  Register highest = and_ps(and_me, *begin);
  for (++begin; begin != end; ++begin) {
    Register reg = and_ps(and_me, *begin);
    highest = max_ps(highest, reg);
  }

  return MaxFloat32(highest);
}*/
// MAXABS_DEFINE(Register, target) defines MaxAbsolute, which returns the
// maximum absolute value of the floats in [begin_float, end_float).  The
// range must be non-empty and a whole number of Registers long.  Because it
// is read through aligned Register loads, begin_float must be aligned to
// sizeof(Register).
#define MAXABS_DEFINE(Register, target) \
target static float MaxAbsolute(const float *begin_float, const float *end_float) { \
  assert(end_float > begin_float); \
  assert((end_float - begin_float) % (sizeof(Register) / sizeof(float)) == 0); \
  /* Same alignment precondition the Multiply functions assert for A and B. */ \
  assert(reinterpret_cast<uintptr_t>(begin_float) % sizeof(Register) == 0); \
  const Register *begin = reinterpret_cast<const Register*>(begin_float); \
  const Register *end = reinterpret_cast<const Register*>(end_float); \
  /* Mask 0x7fffffff clears each float's sign bit, giving its absolute value. \
     NOTE(review): union type punning relies on gcc/clang semantics, not ISO C++. */ \
  union {float f; int32_t i;} float_convert; \
  float_convert.i = 0x7fffffff; \
  Register and_me = set1_ps<Register>(float_convert.f); \
  Register highest = and_ps(and_me, *begin); \
  for (++begin; begin != end; ++begin) { \
    Register reg = and_ps(and_me, *begin); \
    highest = max_ps(highest, reg); \
  } \
  return MaxFloat32(highest); \
}

} // namespace intgemm