
github.com/marian-nmt/intgemm.git
author     Kenneth Heafield <github@kheafield.com>  2018-06-16 13:36:25 +0300
committer  Kenneth Heafield <github@kheafield.com>  2018-06-16 13:36:25 +0300
commit     c90b2d8a81978715a30c90636f2a4fc6bcd37aaf (patch)
tree       dfe1002933bb83676c81dbd7a4a8d63c2b124dee /avx2_gemm.cc
parent     74b2f80dd3b3fff9ff4ce068e38dc36c7dca1686 (diff)
Add assembly code version to avx2 for int8
Diffstat (limited to 'avx2_gemm.cc')
-rw-r--r--  avx2_gemm.cc  141
1 file changed, 133 insertions, 8 deletions
diff --git a/avx2_gemm.cc b/avx2_gemm.cc
index 4c5091f..d46bbb3 100644
--- a/avx2_gemm.cc
+++ b/avx2_gemm.cc
@@ -334,14 +334,6 @@ void MatrixMult8Contrast(const __m256i *A, const __m256i *B, float *C, float unq
for (int k = 0; k < simd_width; ++k) {
// Read in 64 8-bit signed integers from A.
__m256i a = *(A_row + k);
- /* These do the loads from B which is important to do early to hide as
- * much memory latency as possible.
- * It's possible to rearrange B so that these will all be consecutive
- * and benchmarks show that is faster. TODO.
- * Annoyingly the only 8-bit multiply is signed * unsigned (maddubs).
- * So we take the sign bits off of a and apply them each b in a * b.
- * There is a 256-bit sign instruction so we'll try that.
- */
// Negate 8-bit values in b if the corresponding a was negative.
// Negation is implemented by subtraction from zero.
__m256i b0 = _mm256_sign_epi8(*(B0_row + k * 8), a);
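The removed comment above describes the workaround that both the intrinsics version and the new assembly below rely on: AVX2's only 8-bit multiply-add, vpmaddubsw, takes one unsigned and one signed operand, so the sign of a is transferred onto b and |a| is used as the unsigned side. A minimal standalone sketch of one inner-loop step using intrinsics, assuming AVX2 and compilation with -mavx2 (an illustration, not code from this repository):

#include <immintrin.h>

// One step of the signed*unsigned workaround: a and b each hold 32 signed 8-bit values,
// sum accumulates packed 16-bit partial sums.
static inline __m256i MultiplyAddStep(__m256i a, __m256i b, __m256i sum) {
  // Negate bytes of b where the corresponding byte of a is negative (bytes are zeroed
  // where a is zero, which is harmless because |a| is also zero there).
  __m256i b_signed = _mm256_sign_epi8(b, a);
  // |a|, now safe to treat as unsigned.
  __m256i a_abs = _mm256_abs_epi8(a);
  // Unsigned |a| times signed b', with adjacent pairs summed into saturating 16-bit integers.
  __m256i products = _mm256_maddubs_epi16(a_abs, b_signed);
  // Saturating 16-bit accumulate.
  return _mm256_adds_epi16(sum, products);
}

The assembly added below performs these same steps, but schedules them by hand across eight columns of B to hide the latency of the loads and of vpmaddubsw.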
@@ -380,6 +372,139 @@ void MatrixMult8Contrast(const __m256i *A, const __m256i *B, float *C, float unq
}
}
+void MatrixMult8ASM(const __m256i *A, const __m256i *B, float *C, float unquant_mult, int num_A_rows, int num_B_cols, int width) {
+ assert(width % 32 == 0);
+ assert(reinterpret_cast<uintptr_t>(A) % 32 == 0);
+ assert(reinterpret_cast<uintptr_t>(B) % 32 == 0);
+ assert(reinterpret_cast<uintptr_t>(C) % 32 == 0);
+ assert(num_B_cols % 8 == 0);
+ __m256 unquant_reg = _mm256_set1_ps(unquant_mult);
+ const int simd_width = width / 32;
+ int B0_colidx = 0;
+ // Go over 8 columns of B at a time.
+ for (const __m256i *B0_col = B; B0_colidx != num_B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+ // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
+ for (int A_rowidx = 0; A_rowidx < num_A_rows; ++A_rowidx) {
+ const __m256i *A_row = A + A_rowidx * simd_width;
+ // These will be packed 16-bit integers containing sums for each column of B multiplied by the row of A.
+ __m256i sum0 = _mm256_setzero_si256();
+ __m256i sum1 = _mm256_setzero_si256();
+ __m256i sum2 = _mm256_setzero_si256();
+ __m256i sum3 = _mm256_setzero_si256();
+ __m256i sum4 = _mm256_setzero_si256();
+ __m256i sum5 = _mm256_setzero_si256();
+ __m256i sum6 = _mm256_setzero_si256();
+ __m256i sum7 = _mm256_setzero_si256();
+ // Iterate over shared (inner) dimension.
+ for (int k = 0; k < simd_width; ++k) {
+ const __m256i *B_base = B0_col + k * 8;
+ // Read in 64 8-bit signed integers from A.
+ __m256i a = *(A_row + k);
+ // The assembly will store the absolute value of a here.
+ __m256i absa;
+ // Annoyingly the only 8-bit multiply is signed * unsigned (maddubs).
+ // So we take the sign bits off of a and apply them to each b in a * b.
+ //
+ // We have only 16 YMM registers but we want to store:
+ // 1 for a (or |a|)
+ // 8 temporaries for applying sign to each column of B.
+ // 8 sums.
+ //
+ // gcc's register allocator does:
+ // 1 for a, do all the sign application, then overwrite with |a|
+ // 8 temporaries
+ // 7 sums in registers + 1 on the stack
+ //
+ // However, it's possible to complete an operation early, freeing up its
+ // temporary register for reuse. Completing an operation early
+ // requires |a| for vpmaddubsw, while a later operation still needs a
+ // to apply the sign.
+ //
+ // So we do two columns, 0 and 1, early. This allows b0_b6 and b1_b7
+ // to be reused by columns 6 and 7, respectively. And there are enough
+ // registers to store both a and |a|.
+ //
+ // These are the temporary variables used to process each column of b.
+ // We let the compiler choose which register number is which, but force
+ // it to allocate all registers.
+ __m256i b0_b6, b1_b7, b2, b3, b4, b5;
+ asm(
+ // Copy the first 6 columns of b to registers. We assume B has
+ // been rearranged so that these 8 columns are consecutive.
+ // vpsignb does not take a memory address as its second argument,
+ // so the load can't be folded into vpsignb.
+ "vmovdqa (%[B]), %[b0_b6];\n"
+ "vmovdqa 32(%[B]), %[b1_b7];\n"
+ "vmovdqa 64(%[B]), %[b2];\n"
+ "vmovdqa 96(%[B]), %[b3];\n"
+ "vmovdqa 128(%[B]), %[b4];\n"
+ "vmovdqa 160(%[B]), %[b5];\n"
+ // Store the absolute value of a in absa.
+ "vpabsb %[a], %[absa];\n"
+ // If a byte of a is negative, negate the corresponding byte in
+ // b0_b6 etc.
+ "vpsignb %[a], %[b0_b6], %[b0_b6];\n"
+ "vpsignb %[a], %[b1_b7], %[b1_b7];\n"
+ // Multiply signed * unsigned then horizontally add to form packed
+ // 16-bit integers:
+ // b0[0] * |a|[0] + b0[1] * |a|[1], b0[2] * |a|[2] + b0[3] * |a|[3], ...
+ "vpmaddubsw %[b0_b6], %[absa], %[b0_b6];\n"
+ "vpmaddubsw %[b1_b7], %[absa], %[b1_b7];\n"
+ // vpmaddubsw has latency 5 so work on some other sign bits while
+ // we're at it.
+ "vpsignb %[a], %[b2], %[b2];\n"
+ "vpsignb %[a], %[b3], %[b3];\n"
+ "vpsignb %[a], %[b4], %[b4];\n"
+ "vpsignb %[a], %[b5], %[b5];\n"
+ // Perform a 16-bit add with saturation to accumulate sums.
+ "vpaddsw %[b0_b6], %[sum0], %[sum0];\n"
+ // Now we can reuse b0_b6 for b6
+ "vmovdqa 192(%[B]), %[b0_b6];\n"
+ "vpaddsw %[b1_b7], %[sum1], %[sum1];\n"
+ // Now we can reuse b1_b7 for b7
+ "vmovdqa 224(%[B]), %[b1_b7];\n"
+ // More crunching while the load happens.
+ "vpmaddubsw %[b2], %[absa], %[b2];\n"
+ "vpmaddubsw %[b3], %[absa], %[b3];\n"
+ "vpmaddubsw %[b4], %[absa], %[b4];\n"
+ "vpsignb %[a], %[b0_b6], %[b0_b6];\n"
+ "vpsignb %[a], %[b1_b7], %[b1_b7];\n"
+ "vpmaddubsw %[b5], %[absa], %[b5];\n"
+ "vpmaddubsw %[b0_b6], %[absa], %[b0_b6];\n"
+ "vpmaddubsw %[b1_b7], %[absa], %[b1_b7];\n"
+ "vpaddsw %[b2], %[sum2], %[sum2];\n"
+ "vpaddsw %[b3], %[sum3], %[sum3];\n"
+ "vpaddsw %[b4], %[sum4], %[sum4];\n"
+ "vpaddsw %[b5], %[sum5], %[sum5];\n"
+ "vpaddsw %[b0_b6], %[sum6], %[sum6];\n"
+ "vpaddsw %[b1_b7], %[sum7], %[sum7];\n"
+ : [sum0] "+x" (sum0),
+ [sum1] "+x" (sum1),
+ [sum2] "+x" (sum2),
+ [sum3] "+x" (sum3),
+ [sum4] "+x" (sum4),
+ [sum5] "+x" (sum5),
+ [sum6] "+x" (sum6),
+ [sum7] "+x" (sum7),
+ [b0_b6] "=&x" (b0_b6),
+ [b1_b7] "=&x" (b1_b7),
+ [b2] "=&x" (b2),
+ [b3] "=&x" (b3),
+ [b4] "=&x" (b4),
+ [b5] "=&x" (b5),
+ [absa] "=&x" (absa)
+ // Tell gcc precisely what we are reading from RAM.
+ : [B] "r" (*reinterpret_cast<const __m256i (*)[8]>(B_base)),
+ [a] "x" (a)
+ );
+ }
+ // Write to C.
+ __m256i combined = Reduce16to32(sum0, sum1, sum2, sum3, sum4, sum5, sum6, sum7);
+ *reinterpret_cast<__m256*>(C + A_rowidx * num_B_cols + B0_colidx) = _mm256_mul_ps(_mm256_cvtepi32_ps(combined), unquant_reg);
+ }
+ }
+}
+
} // namespace AVX2
#endif // __AVX2__
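The operand list at the end of the asm block uses gcc's extended asm constraints: "+x" ties each sum to a vector register that is both read and written, "=&x" marks each temporary as an early-clobber output so gcc will not assign it the same register as any input, and plain "x" and "r" are read-only register inputs. A minimal sketch of the same pattern, assuming AVX2 (an illustration, not code from this repository):

#include <immintrin.h>

// Adds the bytewise absolute value of a to sum with signed saturation.
// sum is read-modify-write ("+x"); absa is a scratch output that must not
// alias any input ("=&x"); a is a read-only vector input ("x").
static inline __m256i AddAbs(__m256i sum, __m256i a) {
  __m256i absa;
  asm("vpabsb %[a], %[absa]\n\t"
      "vpaddsb %[absa], %[sum], %[sum]\n\t"
      : [sum] "+x" (sum), [absa] "=&x" (absa)
      : [a] "x" (a));
  return sum;
}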
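For checking MatrixMult8ASM, a scalar reference can be read off the indexing above: B0_col advances by 8 * simd_width registers per group of 8 columns, and B_base selects 8 consecutive 32-byte blocks per step of k, so B is assumed to be rearranged so that the 32-entry blocks of the 8 columns in a group are stored consecutively for each k. The sketch below follows that assumption (it is not code from this repository) and accumulates in 32-bit, so it only matches the kernel when the saturating 16-bit sums do not saturate:

#include <cstdint>

// Hypothetical scalar reference, assuming the rearranged B layout described above:
// bytes ordered as [column group][k block of 32][column within group][position in block].
void MatrixMult8Reference(const int8_t *A, const int8_t *B, float *C, float unquant_mult,
                          int num_A_rows, int num_B_cols, int width) {
  for (int i = 0; i < num_A_rows; ++i) {
    for (int col_group = 0; col_group < num_B_cols; col_group += 8) {
      for (int j = 0; j < 8; ++j) {
        int32_t sum = 0;
        for (int k = 0; k < width; ++k) {
          int b_index = col_group * width + (k / 32) * 8 * 32 + j * 32 + (k % 32);
          sum += static_cast<int32_t>(A[i * width + k]) * static_cast<int32_t>(B[b_index]);
        }
        C[i * num_B_cols + col_group + j] = static_cast<float>(sum) * unquant_mult;
      }
    }
  }
}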