Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2018-06-17 14:33:33 +0300
committerKenneth Heafield <github@kheafield.com>2018-06-17 14:33:33 +0300
commit694c597d9d3572331b7bb16b5ab354653b938792 (patch)
tree3bccfe1e2e6fbfd00204a453a54de448cd4e8119
parent6b48ad03d1fe89767bf150a12cdfd7e73be0078c (diff)
Refactoring: progress on SSE2
Clean up tests
-rw-r--r--Benchmark.cc1
-rw-r--r--Makefile12
-rw-r--r--Quantize.cc140
-rw-r--r--Quantize.h27
-rw-r--r--QuantizeTest.cc128
-rw-r--r--Test.cc38
-rw-r--r--aligned.h22
-rw-r--r--avx2_gemm.cc8
-rw-r--r--interleave.h94
9 files changed, 191 insertions, 279 deletions
diff --git a/Benchmark.cc b/Benchmark.cc
index 7a47b1b..530338c 100644
--- a/Benchmark.cc
+++ b/Benchmark.cc
@@ -1,7 +1,6 @@
#include "avx512_gemm.h"
#include "avx2_gemm.h"
#include "SSE_Matrix_Mult.h"
-#include "Quantize.h"
#include "StopWatch.h"
#include <cassert>
diff --git a/Makefile b/Makefile
index e5ae309..6ea1cb9 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
CXX := g++
-CXXFLAGS := -DNDEBUG -Wall -Werror -fPIC -O3 -march=native
-SRC := avx512_gemm.cc avx2_gemm.cc SSE_Matrix_Mult.cc Quantize.cc StopWatch.cc
+CXXFLAGS := -Wall -Werror -fPIC -O3 -march=native
+SRC := avx512_gemm.cc avx2_gemm.cc sse2_gemm.cc SSE_Matrix_Mult.cc Quantize.cc StopWatch.cc
OBJ := ${SRC:.cc=.o}
all: Test QuantizeTest Benchmark
@@ -14,11 +14,11 @@ Test: ${OBJ} Test.o
Benchmark: ${OBJ} Benchmark.o
${CXX} ${CXXFLAGS} ${OBJ} Benchmark.o -o Benchmark
-QuantizeTest: Quantize.o QuantizeTest.o StopWatch.o avx512_gemm.o avx2_gemm.o
- ${CXX} ${CXXFLAGS} Quantize.o QuantizeTest.o StopWatch.o avx512_gemm.o avx2_gemm.o -o QuantizeTest
+QuantizeTest: QuantizeTest.o StopWatch.o avx512_gemm.o avx2_gemm.o sse2_gemm.o
+ ${CXX} ${CXXFLAGS} QuantizeTest.o avx512_gemm.o avx2_gemm.o sse2_gemm.o -o QuantizeTest
-.c.o: AVX_Matrix_Mult.h
+.c.o:
${CXX} ${CXXFLAGS} -c $<
clean:
- rm -f ${OBJ} QuantizeTest Test
+ rm -f ${OBJ} QuantizeTest Test Test.o QuantizeTest.o Benchmark.o
diff --git a/Quantize.cc b/Quantize.cc
deleted file mode 100644
index 9e704cb..0000000
--- a/Quantize.cc
+++ /dev/null
@@ -1,140 +0,0 @@
-/* Quantize to 8-bit and 16-bit signed integers.
- *
- * 8-bit quantization bans -128 because we can't negate it. The maddubs
- * instructions are unsigned * signed so they require sign bit manipulation.
- *
- * The input and output should be aligned appropriately for instructions:
- * 64 bytes for AVX512
- * 32 bytes for AVX2
- * 16 bytes for SSE
- *
- * The size depends on the function, but it's safe to be a multiple of 32.
- */
-#include "Quantize.h"
-
-#include "Print.h"
-
-#include <cassert>
-#include <emmintrin.h>
-#include <immintrin.h>
-#include <math.h>
-#include <stdint.h>
-#include <tmmintrin.h>
-#include <xmmintrin.h>
-
-namespace intgemm {
-
-#ifdef __SSE2__
-namespace SSE {
-
-/* Uses following instructions:
- * SSE: _mm_mul_ps, _mm_load_ps
- * SSE2: _mm_cvtps_epi32, _mm_packs_epi32, _mm_packs_epi32, _mm_cmpeq_epi8, _mm_sub_epi8
- */
-
-// Same implementation as AVX512, just width.
-inline __m128i QuantizerGrab(const float *input, const __m128 quant_mult_reg) {
- return _mm_cvtps_epi32(_mm_mul_ps(_mm_load_ps(input), quant_mult_reg));
-}
-
-/* I also tried an implementation based on _mm_cvtps_pi16 but it was slower:
- * For size 1048576, run 10x in seconds on i7-6700:
- * This code: 0.00228409, 0.00204906
- * With _mm_cvtps_pi16 basis: 0.00391884, 0.00390869
- */
-void Quantize16(const float *input, int16_t *output, float quant_mult, std::size_t size) {
- assert(size % 8 == 0);
- assert(reinterpret_cast<uintptr_t>(input) % 16 == 0);
- const __m128 quant_mult_reg = _mm_set1_ps(quant_mult);
- const float *end = input + size;
- for (; input != end; input += 8, output += 16) {
- __m128i g0 = QuantizerGrab(input, quant_mult_reg);
- __m128i g1 = QuantizerGrab(input + 4, quant_mult_reg);
- __m128i packed = _mm_packs_epi32(g0, g1);
- *reinterpret_cast<__m128i*>(output) = packed;
- }
-}
-
-void Quantize8(const float *input, int8_t *output, float quant_mult, std::size_t size) {
- assert(size % 16 == 0);
- assert(reinterpret_cast<uintptr_t>(input) % 16 == 0);
- const __m128 quant_mult_reg = _mm_set1_ps(quant_mult);
-// const __m128i neg127 = _mm_set1_epi8(-127);
- const __m128i neg128 = _mm_set1_epi8(-128);
- const float *end = input + size;
- for (; input != end; input += 16, output += 16) {
- __m128i g0 = QuantizerGrab(input, quant_mult_reg);
- __m128i g1 = QuantizerGrab(input + 4, quant_mult_reg);
- __m128i g2 = QuantizerGrab(input + 8, quant_mult_reg);
- __m128i g3 = QuantizerGrab(input + 12, quant_mult_reg);
- __m128i packed0 = _mm_packs_epi32(g0, g1);
- __m128i packed1 = _mm_packs_epi32(g2, g3);
- __m128i packed = _mm_packs_epi16(packed0, packed1);
- /* Ban -128.
- * Don't use the SSE4.1 instruction _mm_max_epi8(packed, neg127). Instead,
- * use SSE2 instructions _mm_cmpeq_epi8 and _mm_sub_epi8.
- * The first generates 0xff for fields -128.
- * The second subtracts 0xff from -128 which has the effect of converting
- * to -127.
- */
- // packed = _mm_max_epi8(packed, neg127);
- __m128i evils = _mm_cmpeq_epi8(packed, neg128);
- packed = _mm_sub_epi8(packed, evils);
- // No permute needed. packs is in order for SSE.
- *reinterpret_cast<__m128i*>(output) = packed;
- }
-}
-
-/* This implementation was much slower.
- * For size 1048576, run 10x in seconds on i7-6700:
- * 0.00134197, 0.0013169 Above implementation.
- * 0.00550692, 0.00568323 Below implementation.
- * However, it does have the advantage of using at most SSE2, whereas the above
- * requires SSE4.1 for _mm_max_epi8.
- */
-/*void Quantize8(const float *input, int8_t *output, float quant_mult, std::size_t size) {
- assert(size % 8 == 0);
- assert(reinterpret_cast<uintptr_t>(input) % 16 == 0);
- const float *end = input + size;
- const __m128 quant_mult_reg = _mm_set1_ps(quant_mult);
- const __m64 neg128 = _mm_set1_pi8(-128);
- for (; input < end; input += 8, output += 8) {
- // These both fill the lower 4 elements with 8-bit integers.
- __m64 second = _mm_cvtps_pi8(_mm_mul_ps(_mm_load_ps(input + 4), quant_mult_reg));
- __m64 first = _mm_cvtps_pi8(_mm_mul_ps(_mm_load_ps(input), quant_mult_reg));
- // Shift second right by 32 bits then or into one register.
- __m64 combined = first | _m_psllqi(second, 32);
- // Test for -128, setting 0xff in corresponding fields.
- __m64 evils = _mm_cmpeq_pi8(combined, neg128);
- // Subtract 0xff from -128s to yield -127.
- combined = _mm_sub_pi8(combined, evils);
- *reinterpret_cast<__m64*>(output) = combined;
- }
-}*/
-
-} // namespace SSE
-#endif // __SSE2__
-
-namespace slow {
-
-void Quantize16(const float *input, int16_t *output, float quant_mult, std::size_t size) {
- for (std::size_t i = 0; i < size; ++i) {
- float value = roundf(input[i] * quant_mult);
- value = std::max(-32768.0f, value);
- value = std::min(32767.0f, value);
- output[i] = value;
- }
-}
-
-void Quantize8(const float *input, int8_t *output, float quant_mult, std::size_t size) {
- for (std::size_t i = 0; i < size; ++i) {
- float value = roundf(input[i] * quant_mult);
- value = std::max(-127.0f, value);
- value = std::min(127.0f, value);
- output[i] = value;
- }
-}
-
-} // namespace slow
-
-} // namespace intgemm
diff --git a/Quantize.h b/Quantize.h
deleted file mode 100644
index b9391bb..0000000
--- a/Quantize.h
+++ /dev/null
@@ -1,27 +0,0 @@
-#pragma once
-
-#include <immintrin.h>
-#include <cstddef>
-
-namespace intgemm {
-
-#ifdef __AVX2__
-namespace AVX2 {
-void Quantize16(const float *input, int16_t *output, float quant_mult, std::size_t size);
-void Quantize8(const float *input, int8_t *output, float quant_mult, std::size_t size);
-} // namespace AVX2
-#endif // __AVX2__
-
-#ifdef __SSE2__
-namespace SSE {
-void Quantize16(const float *input, int16_t *output, float quant_mult, std::size_t size);
-void Quantize8(const float *input, int8_t *output, float quant_mult, std::size_t size);
-} // namespace SSE
-#endif // __SSE2__
-
-namespace slow {
-void Quantize16(const float *input, int16_t *output, float quant_mult, std::size_t size);
-void Quantize8(const float *input, int8_t *output, float quant_mult, std::size_t size);
-} // namespace slow
-
-} // namespace intgemm
diff --git a/QuantizeTest.cc b/QuantizeTest.cc
index 3c64be1..f6ed2e6 100644
--- a/QuantizeTest.cc
+++ b/QuantizeTest.cc
@@ -1,6 +1,7 @@
-#include "Quantize.h"
#include "avx512_gemm.h"
-#include "StopWatch.h"
+#include "avx2_gemm.h"
+#include "sse2_gemm.h"
+#include "aligned.h"
#include <cstring>
#include <math.h>
@@ -10,6 +11,24 @@
namespace intgemm {
namespace {
+void QuantizeRef(const float *input, int16_t *output, float quant_mult, std::size_t size) {
+ for (std::size_t i = 0; i < size; ++i) {
+ float value = roundf(input[i] * quant_mult);
+ value = std::max(-32768.0f, value);
+ value = std::min(32767.0f, value);
+ output[i] = value;
+ }
+}
+
+void QuantizeRef(const float *input, int8_t *output, float quant_mult, std::size_t size) {
+ for (std::size_t i = 0; i < size; ++i) {
+ float value = roundf(input[i] * quant_mult);
+ value = std::max(-127.0f, value);
+ value = std::min(127.0f, value);
+ output[i] = value;
+ }
+}
+
template <class I> bool IsOff(float from, I ref, I test) {
if (ref == test) return false;
if (ref - test > 1 && test - ref > 1) return true;
@@ -20,101 +39,46 @@ template <class I> bool IsOff(float from, I ref, I test) {
return true;
}
-bool Test(const float *input_unaligned, float quant_mult, std::size_t size) {
+template <class Backend> bool Test(const float *input_unaligned, float quant_mult, std::size_t size) {
+ typedef typename Backend::Integer Integer;
bool success = true;
- float *input = static_cast<float*>(aligned_alloc(64, sizeof(float) * size));
- std::memcpy(input, input_unaligned, sizeof(float) * size);
- void *mem = aligned_alloc(64, sizeof(int16_t) * size * 2);
- int16_t *ref16 = static_cast<int16_t*>(mem);
- int16_t *test16 = ref16 + size;
- slow::Quantize16(input, ref16, quant_mult, size);
- AVX2::Quantize16(input, test16, quant_mult, size);
- for (std::size_t i = 0; i < size; ++i) {
- if (IsOff(input[i] * quant_mult, ref16[i], test16[i])) {
- std::cerr << "16-bit error at " << i << " from " << input[i] << '*' << quant_mult << '=' << (input[i]*quant_mult) << " ref = " << ref16[i] << " test = " << test16[i] << '\n';
- success = false;
- }
- }
+ free_ptr<float> input(AlignedArray<float>(size));
+ std::memcpy(input.get(), input_unaligned, sizeof(float) * size);
- int8_t *ref8 = static_cast<int8_t*>(mem);
- int8_t *test8 = ref8 + size;
- slow::Quantize8(input, ref8, quant_mult, size);
- AVX2::Quantize8(input, test8, quant_mult, size);
+ free_ptr<Integer> ref(AlignedArray<Integer>(size));
+ free_ptr<Integer> test(AlignedArray<Integer>(size));
+ QuantizeRef(input.get(), ref.get(), quant_mult, size);
+ Backend::Quantize(input.get(), test.get(), quant_mult, size);
for (std::size_t i = 0; i < size; ++i) {
- if (IsOff(input[i] * quant_mult, ref8[i], test8[i])) {
- std::cerr << "8-bit error at " << i << " from " << input[i] << '*' << quant_mult << "=" << (input[i]*quant_mult) << " ref = " << (int16_t)ref8[i] << " test = " << (int16_t)test8[i] << '\n';
+ if (IsOff(input.get()[i] * quant_mult, ref.get()[i], test.get()[i])) {
+ std::cerr << "Error at " << i << " from " << input.get()[i] << '*' << quant_mult << '=' << (input.get()[i]*quant_mult) << " ref = " << ref.get()[i] << " test = " << test.get()[i] << '\n';
success = false;
}
}
-
- free(input);
- free(mem);
return success;
}
-void Benchmark(std::size_t size) {
- float *input = (float*)aligned_alloc(64, sizeof(float) * size);
- void *output = aligned_alloc(64, sizeof(int16_t) * size);
- int8_t *out8 = (int8_t*)output;
- int16_t *out16 = (int16_t*)output;
- for (std::size_t i = 0; i < size; ++i) {
- input[i] = i;
- }
-#ifdef __AVX512F__
- // Burn in.
- slow::Quantize16(input, out16, 3, size);
- {
- StopWatch w("AVX512 16-bit");
- for (int i = 0; i < 10; ++i)
- AVX512::Quantize16(input, out16, 3, size);
- }
-#endif
- slow::Quantize16(input, out16, 3, size);
- {
- StopWatch w("AVX2 16-bit");
- for (int i = 0; i < 10; ++i)
- AVX2::Quantize16(input, out16, 3, size);
- }
- slow::Quantize16(input, out16, 3, size);
- {
- StopWatch w("SSE 16-bit");
- for (int i = 0; i < 10; ++i)
- SSE::Quantize16(input, out16, 3, size);
- }
-#ifdef __AVX512F__
- slow::Quantize8(input, out8, 3, size);
- {
- StopWatch w("AVX512 8-bit");
- for (int i = 0; i < 10; ++i)
- AVX512::Quantize8(input, out8, 3, size);
- }
-#endif
- slow::Quantize8(input, out8, 3, size);
- {
- StopWatch w("AVX2 8-bit");
- for (int i = 0; i < 10; ++i)
- AVX2::Quantize8(input, out8, 3, size);
- }
- slow::Quantize8(input, out8, 3, size);
- {
- StopWatch w("SSE 8-bit");
- for (int i = 0; i < 10; ++i)
- SSE::Quantize8(input, out8, 3, size);
- }
+template <class Backend> bool TestMany() {
+ bool success = true;
+ float input[32] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
+ success &= Test<Backend>(input, 1.0, 32);
+ success &= Test<Backend>(input, 32.0, 32);
+ float corners[32] = {-32769, -32768, -32767, -129, -128, -127, -1, 0, 1, 126, 127, 128, 129, 32766, 32768, 32769, -1.9, -1.5, -1.1, -1, -0.9, -0.5, -0.1, 0.0, 0.1, 0.5, 0.9, 1.0, 1.1, 1.5, 1.9, 16056.8};
+ success &= Test<Backend>(corners, 1.0, sizeof(corners) / sizeof(float));
+ success &= Test<Backend>(corners, -1.0, sizeof(corners) / sizeof(float));
+ success &= Test<Backend>(corners, -0.49, sizeof(corners) / sizeof(float));
+ return success;
}
} // namespace
} // namespace intgemm
int main() {
- float input[32] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
+ using namespace intgemm;
bool success = true;
- success &= intgemm::Test(input, 1.0, 32);
- success &= intgemm::Test(input, 32.0, 32);
- float corners[32] = {-32769, -32768, -32767, -129, -128, -127, -1, 0, 1, 126, 127, 128, 129, 32766, 32768, 32769, -1.9, -1.5, -1.1, -1, -0.9, -0.5, -0.1, 0.0, 0.1, 0.5, 0.9, 1.0, 1.1, 1.5, 1.9, 16056.8};
- success &= intgemm::Test(corners, 1.0, sizeof(corners) / sizeof(float));
- success &= intgemm::Test(corners, -1.0, sizeof(corners) / sizeof(float));
- success &= intgemm::Test(corners, -0.49, sizeof(corners) / sizeof(float));
- intgemm::Benchmark(1048576);
+ success &= TestMany<AVX2_8bit>();
+ success &= TestMany<AVX2_16bit>();
+ success &= TestMany<SSE2_8bit>();
+ success &= TestMany<SSE2_16bit>();
return success ? 0 : 1;
}
diff --git a/Test.cc b/Test.cc
index 788d897..338a210 100644
--- a/Test.cc
+++ b/Test.cc
@@ -22,8 +22,9 @@
#include "avx512_gemm.h"
#include "avx2_gemm.h"
-#include "SSE_Matrix_Mult.h"
-#include "Quantize.h"
+#include "sse2_gemm.h"
+#include "aligned.h"
+#include "interleave.h"
#include "StopWatch.h"
#include <cassert>
@@ -37,18 +38,6 @@
namespace intgemm {
-struct DeleteWithFree {
- template <class T> void operator() (T *t) const {
- std::free(const_cast<std::remove_const_t<T>* >(t));
- }
-};
-template <class T> using free_ptr = std::unique_ptr<T, DeleteWithFree>;
-
-// Return memory suitably aligned for SIMD.
-template <class T> T* AlignedArray(std::size_t size) {
- return static_cast<T*>(aligned_alloc(64, size * sizeof(T)));
-}
-
// Rearrange a tile of simd x unroll entries.
template <class V> void SlowRearrangeTile(const V *from, V *to, int simd, int unroll, int cols) {
for (int i = 0; i < unroll; ++i) {
@@ -75,6 +64,25 @@ template <class V> void SlowTranspose(const V *from, V *to, int rows, int cols)
}
}
+void TestTranspose() {
+ free_ptr<int16_t> input(AlignedArray<int16_t>(8 * 8));
+ for (int16_t i = 0; i < 64; ++i) {
+ input.get()[i] = i;
+ }
+ free_ptr<int16_t> ref(AlignedArray<int16_t>(8 * 8));
+ SlowTranspose(input.get(), ref.get(), 8, 8);
+
+ // Overwrite input.
+ __m128i *t = reinterpret_cast<__m128i*>(input.get());
+ Transpose16InLane(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7]);
+
+ for (int16_t i = 0; i < 64; ++i) {
+ if (ref.get()[i] != input.get()[i]) {
+ std::cerr << "Transpose failure at " << i << ": " << ref.get()[i] << " != " << input.get()[i] << '\n';
+ }
+ }
+}
+
template <class Routine> void TestPrepare(int rows = 32, int cols = 16) {
// Create array.
free_ptr<float> input(AlignedArray<float>(rows * cols));
@@ -191,8 +199,10 @@ void TestBoth(int A_rows, int width, int B_cols) {
int main(int argc, char ** argv) {
std::srand(45678);
using namespace intgemm;
+ TestTranspose();
TestPrepare<AVX2_8bit>(64, 32);
TestPrepare<AVX2_16bit>(64, 32);
+ TestPrepare<SSE2_16bit>(8, 8);
// Top matrix sizes from Marian
TestBoth(8, 256, 256);
TestBoth(8, 2048, 256);
diff --git a/aligned.h b/aligned.h
new file mode 100644
index 0000000..44889d1
--- /dev/null
+++ b/aligned.h
@@ -0,0 +1,22 @@
+#pragma once
+
+// Define allocation like:
+// free_ptr<Integer> quantized(AlignedArray<Integer>(rows * cols));
+// This is only used by tests.
+
+#include <memory>
+
+namespace intgemm {
+
+struct DeleteWithFree {
+ template <class T> void operator() (T *t) const {
+ std::free(const_cast<std::remove_const_t<T>* >(t));
+ }
+};
+template <class T> using free_ptr = std::unique_ptr<T, DeleteWithFree>;
+// Return memory suitably aligned for SIMD.
+template <class T> T* AlignedArray(std::size_t size) {
+ return static_cast<T*>(aligned_alloc(64, size * sizeof(T)));
+}
+
+} // namespace intgemm
diff --git a/avx2_gemm.cc b/avx2_gemm.cc
index 63da5b8..56d00e4 100644
--- a/avx2_gemm.cc
+++ b/avx2_gemm.cc
@@ -116,14 +116,6 @@ void AVX2_8bit::Quantize(const float *input, int8_t *output, float quant_mult, i
// ... ...
namespace {
-// Input: 8-bit integers
-// first f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15 f16 f17 f18 f19 f20 f21 f22 f23 f24 f25 f26 f27 f28 f29 f30 f31
-// second s0 s1 s2 s3 s4 s5 s6 s7 s8 s9 s10 s11 s12 s13 s14 s15 s16 s17 s18 s19 s20 s21 s22 s23 s24 s25 s26 s27 s28 s29 s30 s31
-// Output:
-// first [f0 s0 f1 s1 f2 s2 f3 s3 f4 s4 f5 s5 f6 s6 f7 s7] [f16 s16 f17 s17 f18 s18 f19 s19 f20 s20 f21 s21 f22 s22 f23 s23]
-// second [f8 s8 f9 s9 f10 s10 f11 s11 f12 s12 f13 s13 f14 s14 f15 s15] [f24 s24 f25 s25 f26 s26 f27 s27 f28 s28 f29 s29 f30 s30 f31 s31]
-INTGEMM_INTERLEAVE(__m256i, 256)
-
inline void ReshapeToEights16(const float *input, __m256 quant_mult_reg, int cols, __m256i &out0, __m256i &out1, __m256i &out2, __m256i &out3) {
out0 = QuantizeTile16(input, input + 8 * cols, quant_mult_reg);
out2 = QuantizeTile16(input + 1 * cols, input + 9 * cols, quant_mult_reg);
diff --git a/interleave.h b/interleave.h
index 7639658..3df57be 100644
--- a/interleave.h
+++ b/interleave.h
@@ -1,5 +1,12 @@
#pragma once
+#include <emmintrin.h>
+#include <immintrin.h>
+#include <tmmintrin.h>
+#include <xmmintrin.h>
+
+namespace intgemm {
+
/* This macro defines functions that interleave their arguments like
* inline void Interleave8(__m256i &first, __m256i &second) {
* __m256i temp = _mm256_unpacklo_epi8(first, second);
@@ -12,7 +19,6 @@
* INTGEMM_INTERLEAVE(__m256i, 256)
* INTGEMM_INTERLEAVE(__m512i, 512)
*/
-
#define INTGEMM_INTERLEAVE(type, prefix) \
inline void Interleave8(type &first, type &second) { \
type temp = _mm##prefix##_unpacklo_epi8(first, second); \
@@ -35,3 +41,89 @@ inline void Interleave64(type &first, type &second) { \
first = temp; \
}
+#ifdef __SSE2__
+INTGEMM_INTERLEAVE(__m128i, )
+#endif
+#ifdef __AVX2__
+INTGEMM_INTERLEAVE(__m256i, 256)
+#endif
+#ifdef __AVX512__
+INTGEMM_INTERLEAVE(__m512i, 512)
+#endif
+
+/* Transpose registers containing 8 packed 16-bit integers.
+ * Each 128-bit lane is handled independently.
+ */
+template <class Register> inline void Transpose16InLane(Register &r0, Register &r1, Register &r2, Register &r3, Register &r4, Register &r5, Register &r6, Register &r7) {
+ // r0: columns 0 1 2 3 4 5 6 7 from row 0
+ // r1: columns 0 1 2 3 4 5 6 7 from row 1
+
+ Interleave16(r0, r1);
+ Interleave16(r2, r3);
+ Interleave16(r4, r5);
+ Interleave16(r6, r7);
+ // r0: columns 0 0 1 1 2 2 3 3 from rows 0 and 1
+ // r1: columns 4 4 5 5 6 6 7 7 from rows 0 and 1
+ // r2: columns 0 0 1 1 2 2 3 3 from rows 2 and 3
+ // r3: columns 4 4 5 5 6 6 7 7 from rows 2 and 3
+ // r4: columns 0 0 1 1 2 2 3 3 from rows 4 and 5
+ // r5: columns 4 4 5 5 6 6 7 7 from rows 4 and 5
+ // r6: columns 0 0 1 1 2 2 3 3 from rows 6 and 7
+ // r7: columns 4 4 5 5 6 6 7 7 from rows 6 and 7
+
+ Interleave32(r0, r2);
+ Interleave32(r1, r3);
+ Interleave32(r4, r6);
+ Interleave32(r5, r7);
+ // r0: columns 0 0 0 0 1 1 1 1 from rows 0, 1, 2, and 3
+ // r1: columns 4 4 4 4 5 5 5 5 from rows 0, 1, 2, and 3
+ // r2: columns 2 2 2 2 3 3 3 3 from rows 0, 1, 2, and 3
+ // r3: columns 6 6 6 6 7 7 7 7 from rows 0, 1, 2, and 3
+ // r4: columns 0 0 0 0 1 1 1 1 from rows 4, 5, 6, and 7
+ // r5: columns 4 4 4 4 5 5 5 5 from rows 4, 5, 6, and 7
+ // r6: columns 2 2 2 2 3 3 3 3 from rows 4, 5, 6, and 7
+ // r7: columns 6 6 6 6 7 7 7 7 from rows 4, 5, 6, and 7
+
+ Interleave64(r0, r4);
+ Interleave64(r1, r5);
+ Interleave64(r2, r6);
+ Interleave64(r3, r7);
+ // r0: columns 0 0 0 0 0 0 0 0 from rows 0 through 7
+ // r1: columns 4 4 4 4 4 4 4 4 from rows 0 through 7
+ // r2: columns 2 2 2 2 2 2 2 2 from rows 0 through 7
+ // r3: columns 6 6 6 6 6 6 6 6 from rows 0 through 7
+ // r4: columns 1 1 1 1 1 1 1 1 from rows 0 through 7
+ // r5: columns 5 5 5 5 5 5 5 5 from rows 0 through 7
+
+ // Empirically gcc is able to remove these movs and just rename the outputs of Interleave64.
+ // Swap r1 and r4
+ Register tmp = r4;
+ r4 = r1;
+ r1 = tmp;
+ // Swap r3 and r6.
+ tmp = r3;
+ r3 = r6;
+ r6 = tmp;
+}
+
+/* Tranpose registers containing 16 packed 8-bit integers.
+ * Each 128-bit lane is handled independently.
+ */
+template <class Register> inline void Transpose8InLane(
+ Register &r0, Register &r1, Register &r2, Register &r3, Register &r4, Register &r5, Register &r6, Register &r7,
+ Register &r8, Register &r9, Register &r10, Register &r11, Register &r12, Register &r13, Register &r14, Register &r15) {
+ // Get 8-bit values to 16-bit values so they can travel together.
+ Interleave8(r0, r1);
+ // r0: columns 0 0 1 1 2 2 3 3 4 4 5 5 6 6 7 7 from rows 0 and 1.
+ Interleave8(r2, r3);
+ Interleave8(r4, r5);
+ Interleave8(r6, r7);
+ Interleave8(r8, r9);
+ Interleave8(r10, r11);
+ Interleave8(r12, r13);
+ Interleave8(r14, r15);
+ Transpose16InLane(r0, r2, r4, r6, r8, r10, r12, r14);
+ Transpose16InLane(r1, r3, r5, r7, r9, r11, r13, r15);
+}
+
+} // namespace intgemm