Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2018-05-10 14:23:20 +0300
committerKenneth Heafield <github@kheafield.com>2018-05-10 14:23:20 +0300
commit27459bae2926a203e3da071851a838b3cbdfa588 (patch)
tree81bbdef1d8bcb73c999932f7dd92d9301b283418 /AVX_Matrix_Mult.cc
parentabc16f330add699d7fae0536a3be77e5aace6cc0 (diff)
Standardize file suffixes
Diffstat (limited to 'AVX_Matrix_Mult.cc')
-rw-r--r--AVX_Matrix_Mult.cc192
1 files changed, 192 insertions, 0 deletions
diff --git a/AVX_Matrix_Mult.cc b/AVX_Matrix_Mult.cc
new file mode 100644
index 0000000..9ee19fe
--- /dev/null
+++ b/AVX_Matrix_Mult.cc
@@ -0,0 +1,192 @@
+// This is an AVX512F implementation of int16_t multiply based on Jacob
+// Devlin's SSE code. The original SSE code was:
+
+// Copyright (c) 2017 Microsoft Corporation
+
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+#include "AVX_Matrix_Mult.h"
+
+#include <cassert>
+#include <emmintrin.h>
+#include <immintrin.h>
+#include <math.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <tmmintrin.h>
+#include <xmmintrin.h>
+
+// TODO: Additional improvements can also be made from unrolling the for loop over num_B_rows in SSE_MatrixMult, which is not done here for clarity.
+
+// ***************************************
+// ************** IMPORTANT **************
+// ***************************************
+// The biggest "gotcha" when using this type of multiplication is dealing with overflow related to quantization.
+// It is NOT enough to simply ensure that A and B fit into 16 bit integers. If A and B are quantized with $n$ bits,
+// the result of multiplying them together will be quantized to $n^2$ bits. So if they are near the boundary of the 16-bit
+// mark, then the result will be near 32-bits and overflow. However, if we use, say, n = 10 bits, then the product is 20 bits.
+// This gives us 12 bits left over for the accumulation. So as long as the width of the common dimension is less than 2^12 = 4096, it is
+// *impossible* to overflow. If we used, say, n = 12 bits, then we have 32-(12*2) = 8 bits left over. So we *could* overflow if width > 2^8.
+//
+// So, the tradeoff is between quantization precision and possibility of overflow. A good general value is 10 bits, since this gives high precision
+// (precision is 1/2^10 ~= 0.001, which is more than what's needed for almost all neural nets), and cannot overflow unless the matrix width is > 4096.
+
+// This quantizes floating point values into fixed-point 16-bit integers. Effectively, we are performing an SSE version of
+// float x = ...;
+// int16_t y = (int16_t)(quant_mult*x);
+//
+// Except that the casting is saturated. However, you should always ensure that the input fits into a fixed range anyways.
+// I.e., you should ensure that quant_mult*x fits into the range [-2^15, 2^15].
+// This should always be possible because the value you're quantizing will either be NN weights or NN activations, both of
+// which can be clipped to a fixed range during training.
+void AVX_Quantize(const float * input, __m256i * output, float quant_mult, int size) {
+ assert(size % 16 == 0);
+ assert(reinterpret_cast<uintptr_t>(input) % 64 == 0);
+ assert(reinterpret_cast<uintptr_t>(output) % 64 == 0);
+ // Annoyingly, _mm512_packs_epi32 requires AVX512BW which isn't supported
+ // on my target. Therefore I use _mm256_packs_epi32.
+
+ // Fill with the quantization multiplier.
+ const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult);
+ const float *end = input + size;
+
+ for (; input != end; input += 16, output += 1) {
+ // Load 16 floats
+ __m512 val = _mm512_load_ps(input);
+ // Multiply each by the quantization factor.
+ val = _mm512_mul_ps(val, quant_mult_reg);
+ // Cast to 32-bit int
+ __m512i as_int = _mm512_cvtps_epi32(val);
+ // Pack into 16-bit ints with saturation.
+ // I would do two AVX512 registers and _mm512_packs_epi32 but that's not
+ // AVX515F.
+ *output = _mm256_packs_epi32(_mm512_castsi512_si256(as_int), _mm512_extracti64x4_epi64(as_int, 1));
+ }
+}
+
+namespace {
+
+// Assuming sum1, sum2, sum3, and sum4 are arrays 32-bit signed integers,
+// reduce within each.
+// Returns [sum(sum1), sum(sum2), sum(sum3), sum(sum4)]
+// TODO: consider doing in 64-bit, allowing 4 more bits of quantization?
+inline __m128i Reduce(__m512i sum1, __m512i sum2, __m512i sum3, __m512i sum4) {
+ // 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2
+ __m512i pack12 = _mm512_add_epi32(_mm512_unpackhi_epi32(sum1, sum2), _mm512_unpacklo_epi32(sum1, sum2));
+ // 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4
+ __m512i pack34 = _mm512_add_epi32(_mm512_unpackhi_epi32(sum3, sum4), _mm512_unpacklo_epi32(sum3, sum4));
+ // 1 2 3 4 1 2 3 4 1 2 3 4 1 2 3 4
+ __m512i pack1234 = _mm512_add_epi32(_mm512_unpackhi_epi64(pack12, pack34), _mm512_unpacklo_epi64(pack12, pack34));
+ // Cut the register into halves and sum those. 1 2 3 4 1 2 3 4
+ __m256i halves = _mm256_add_epi32(_mm512_castsi512_si256(pack1234), _mm512_extracti64x4_epi64(pack1234, 1));
+ // Again: cut the register into halves and sum those. 1 2 3 4
+ __m128i ret = _mm_add_epi32(_mm256_castsi256_si128(halves), _mm256_extracti128_si256(halves, 1));
+ return ret;
+}
+
+union FloatAccess {
+ float as_f[4];
+ __m128 as_n;
+};
+
+} // namespace
+
+// We are multiplying A * B^T, as opposed to A * B. This is important because it means we can do consecutive memory access on A * B^T which allows to to take the most
+// advantage of L1 cache.
+//
+// B is typically a weight matrix, so it can be pre-processed offline, and therefore this transpose does not cost anything.
+// A is typically an activation minibatch matrix.
+// A and B must be 64-byte aligned.
+// C should be the usual 4-byte alignment.
+void AVX_MatrixMult(const __m512i * A, const __m512i * B, float * C, float unquant_mult, int num_A_rows, int num_B_rows, int width) {
+ assert(num_A_rows % 4 == 0);
+ assert(width % 32 == 0);
+ assert(reinterpret_cast<uintptr_t>(A) % 64 == 0);
+ assert(reinterpret_cast<uintptr_t>(B) % 64 == 0);
+
+ const __m128 unquant_mult_sse = _mm_set1_ps(unquant_mult);
+
+ const int sse_width = width/32;
+
+ // We do loop unrolling over A. This is *significantly* faster
+ // since B can live in the registers. We are assuming that
+ // A is a multiple of 4, but we can add extra code to handle values of 1, 2, 3.
+ //
+ // We could also do loop unrolling over B, which adds some additional speedup.
+ // We don't do that for the sake of clarity.
+ //
+ // There are other memory access patterns we could do, e.g., put B on the outer loop.
+ // The justification is that A is typically small enough that it can live in L1 cache.
+ // B is usually a larger weight matrix, so it might not be able to. However, we are using
+ // each element of B four times while it's still in a register, so caching is not as important.
+ for (int i = 0; i < num_A_rows; i += 4) {
+ const __m512i * A1_row = A + (i+0)*sse_width;
+ const __m512i * A2_row = A + (i+1)*sse_width;
+ const __m512i * A3_row = A + (i+2)*sse_width;
+ const __m512i * A4_row = A + (i+3)*sse_width;
+
+ for (int j = 0; j < num_B_rows; j++) {
+ const __m512i * B_row = B + j*sse_width;
+
+ __m512i sum1 = _mm512_setzero_si512();
+ __m512i sum2 = _mm512_setzero_si512();
+ __m512i sum3 = _mm512_setzero_si512();
+ __m512i sum4 = _mm512_setzero_si512();
+
+ // This is just a simple dot product, unrolled four ways.
+ for (int k = 0; k < sse_width; k++) {
+ __m512i b = *(B_row + k);
+
+ __m512i a1 = *(A1_row + k);
+ __m512i a2 = *(A2_row + k);
+ __m512i a3 = *(A3_row + k);
+ __m512i a4 = *(A4_row + k);
+
+ // madd_epi16 does multiply add on 8 16-bit integers and accumulates into a four 32-bit register.
+ // E.g.,
+ // a1 = [f1, f2, f3, f4, f5, f6, f7, h8] (16-bit ints)
+ // b1 = [h1, h2, h3, h4, h5, h6, h7, h8] (16-bit ints)
+ // result = [f1*h1 + f2*h2, f3*h3 + f4*h4, f5*h5 + f6*h6, f7*h7 + f8*h8] (32-bit ints)
+ // Then add_epi32 just effectively does a += on these 32-bit integers.
+ sum1 = _mm512_add_epi32(sum1, _mm512_madd_epi16(b, a1));
+ sum2 = _mm512_add_epi32(sum2, _mm512_madd_epi16(b, a2));
+ sum3 = _mm512_add_epi32(sum3, _mm512_madd_epi16(b, a3));
+ sum4 = _mm512_add_epi32(sum4, _mm512_madd_epi16(b, a4));
+ }
+ FloatAccess a;
+ // Get floats for each of the sums to write.
+ a.as_n = _mm_cvtepi32_ps(Reduce(sum1, sum2, sum3, sum4));
+ // Undo quantization scaling.
+ a.as_n = _mm_mul_ps(a.as_n, unquant_mult_sse);
+ // Also note that the memory acceses on C are not consecutive, but this is a tradeoff that we have to make.
+ // We can't have consecutive accesses of A, B, *and* C. But we access A and B a lot more so it makes
+ // sense to do it this way.
+ // Scatter to outputs:
+ *(C + (i+0)*num_B_rows + j) = a.as_f[0];
+ *(C + (i+1)*num_B_rows + j) = a.as_f[1];
+ *(C + (i+2)*num_B_rows + j) = a.as_f[2];
+ *(C + (i+3)*num_B_rows + j) = a.as_f[3];
+ /* Sadly the scatter instruction requires avx512vl
+ * _mm_i32scatter_ps(C + i * num_B_rows + j, num_b_rows_scatter, float_sums, sizeof(float));
+ */
+ }
+ }
+}