Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Kenneth Heafield <github@kheafield.com> 2020-04-04 01:01:15 +0300
committer Kenneth Heafield <github@kheafield.com> 2020-04-04 01:01:15 +0300
commit 57fe315ab631fb44c909f90868584536d05ec0ce (patch)
tree 329693c7ee39c8f40bf5c680d70ef10b38d0512f
parent 9af1ca24ea9a427771f29dd2ab4f7e9a889bf934 (diff)
Reduce working for SSE2 and AVX2, working on AVX512
-rw-r--r-- test/tile_test.inl 61
-rw-r--r-- tile/reduce.h 12
-rw-r--r-- tile/reduce.inl 125
3 files changed, 137 insertions, 61 deletions
diff --git a/test/tile_test.inl b/test/tile_test.inl
index 98a1300..c612ace 100644
--- a/test/tile_test.inl
+++ b/test/tile_test.inl
@@ -57,43 +57,38 @@ TEST_CASE("Basic Tile " INTGEMM_TEST_NAME, "[tile]") {
}
}
-INTGEMM_TARGET void DumpRegister(Register reg) {
- int32_t values[sizeof(Register) / sizeof(int32_t)];
- memcpy(values, &reg, sizeof(Register));
- for (std::size_t i = 0; i < sizeof(Register) / sizeof(int32_t); ++i) {
- std::cout.width(11);
- std::cout << values[i] << ' ';
- }
-}
-
-INTGEMM_TARGET void Pack32Test() {
- const std::size_t kPack = sizeof(Register) / sizeof(int32_t);
- Register regs[kPack];
- std::mt19937 gen;
- //std::uniform_int_distribution<int32_t> dist(std::numeric_limits<int32_t>::min(), std::numeric_limits<int32_t>::max());
- std::uniform_int_distribution<int32_t> dist(0, 100);
- std::vector<int32_t> reference(kPack, 0);
- for (std::size_t i = 0; i < kPack; ++i) {
- int32_t temp[kPack];
- for (std::size_t j = 0; j < kPack; ++j) {
- temp[j] = dist(gen);
- reference[j] += temp[j];
+struct Pack32Test {
+ template <typename Iterator> INTGEMM_TARGET static void body() {
+ constexpr Index Valid = Iterator::template I<0>();
+ // A zero-length array is a compiler error, so force it to be longer.
+ constexpr Index ArrayLen = Valid ? Valid : 1;
+ const std::size_t kPack = sizeof(Register) / sizeof(int32_t);
+ Register regs[ArrayLen];
+ std::mt19937 gen;
+ std::uniform_int_distribution<int32_t> dist(std::numeric_limits<int32_t>::min(), std::numeric_limits<int32_t>::max());
+ int32_t reference[ArrayLen];
+ memset(reference, 0, sizeof(reference));
+ for (Index i = 0; i < Valid; ++i) {
+ int32_t temp[kPack];
+ for (std::size_t j = 0; j < kPack; ++j) {
+ temp[j] = dist(gen);
+ reference[i] += temp[j];
+ }
+ memcpy(&regs[i], temp, sizeof(Register));
+ }
+ // Decay type for template.
+ Register *indirect = regs;
+ Pack32<Valid, Sum32Op>(indirect);
+ const int32_t *test = reinterpret_cast<const int32_t*>(regs);
+ for (Index i = 0; i < Valid; ++i) {
+ CHECK(test[i] == reference[i]);
}
- memcpy(&regs[i], temp, sizeof(Register));
- }
- Register *indirect = regs;
- for (std::size_t i = 0; i < 4; ++i) {
- DumpRegister(indirect[i]);
- std::cout << '\n';
}
- Pack32<3, Sum32Op>(indirect);
- DumpRegister(indirect[0]);
- std::cout << '\n';
-}
+};
TEST_CASE("Reduce " INTGEMM_TEST_NAME, "[tile]") {
- if (kCPU >= CPUType::INTGEMM_ARCH)
- Pack32Test();
+ if (kCPU < CPUType::INTGEMM_ARCH) return;
+ StaticLoop<Pack32Test, MakeStaticLoopIterator<33>>();
}
} // namespace INTGEMM_ARCH
diff --git a/tile/reduce.h b/tile/reduce.h
index 0ab4e0d..51cf719 100644
--- a/tile/reduce.h
+++ b/tile/reduce.h
@@ -24,19 +24,19 @@ struct Sum32Op {
} // namespace intgemm
-#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
-#define INTGEMM_THIS_IS_AVX512BW
+#define INTGEMM_THIS_IS_SSE2
#include "reduce.inl"
-#undef INTGEMM_THIS_IS_AVX512BW
-#endif
+#undef INTGEMM_THIS_IS_SSE2
#define INTGEMM_THIS_IS_AVX2
#include "reduce.inl"
#undef INTGEMM_THIS_IS_AVX2
-#define INTGEMM_THIS_IS_SSE2
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+#define INTGEMM_THIS_IS_AVX512BW
#include "reduce.inl"
-#undef INTGEMM_THIS_IS_SSE2
+#undef INTGEMM_THIS_IS_AVX512BW
+#endif
namespace intgemm {
diff --git a/tile/reduce.inl b/tile/reduce.inl
index 3bfa899..c3f7470 100644
--- a/tile/reduce.inl
+++ b/tile/reduce.inl
@@ -14,42 +14,123 @@
namespace intgemm {
namespace INTGEMM_ARCH {
-template <class Op> struct Pack64Even {
+struct RegisterPair { Register hi; Register lo; };
+
+template <class Op, class Folder> struct PackEvens {
template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) {
const Index i = Iterator::template I<0>();
- Register hi = unpackhi_epi64(regs[2 * i], regs[2 * i + 1]);
- Register lo = unpacklo_epi64(regs[2 * i], regs[2 * i + 1]);
- regs[i] = Op::Run(hi, lo);
+ RegisterPair ret = Folder::Even(regs[2 * i], regs[2 * i + 1]);
+ regs[i] = Op::Run(ret.hi, ret.lo);
}
};
-template <Index Valid, class Op> INTGEMM_TARGET static inline void Pack64(Register *regs) {
- StaticLoop<Pack64Even<Op>, MakeStaticLoopIterator<Valid / 2>>(regs);
+
+template <Index Valid, class Op, class Folder> INTGEMM_TARGET static inline void GenericPack(Register *regs) {
+ StaticLoop<PackEvens<Op, Folder>, MakeStaticLoopIterator<Valid / 2>>(regs);
if (Valid & 1) {
- // For the odd case, shuffle to form 0 g where g is garbage and 0 is accumlated.
- Register shuffled = shuffle_epi32(regs[Valid - 1], 0xB0 /* CDAA */);
- regs[Valid / 2] = Op::Run(shuffled, regs[Valid - 1]);
+ auto values = Folder::Odd(regs[Valid - 1]);
+ regs[Valid / 2] = Folder::OddUpcast(Op::Run(values.lo, values.hi));
}
- // Now [0, (Valid + 1) / 2) contains registers to pack with 128-bit interleaving.
}
-template <class Op> struct Pack32Even {
+struct Pack32Folder {
+ INTGEMM_TARGET static inline RegisterPair Even(Register first, Register second) {
+ return RegisterPair { unpackhi_epi32(first, second), unpacklo_epi32(first, second) };
+ }
+ INTGEMM_TARGET static inline RegisterPair Odd(Register reg) {
+ // For the odd case, shuffle to form 0 g 0 g where g is garbage and 0 is accumulated.
+ return RegisterPair { reg, shuffle_epi32(reg, 0x31) };
+ }
+ INTGEMM_TARGET static inline Register OddUpcast(Register reg) { return reg; }
+};
+
+struct Pack64Folder {
+ INTGEMM_TARGET static inline RegisterPair Even(Register first, Register second) {
+ return RegisterPair { unpackhi_epi64(first, second), unpacklo_epi64(first, second) };
+ }
+ INTGEMM_TARGET static inline RegisterPair Odd(Register reg) {
+ // For the odd case, shuffle to form 0 g where g is garbage and 0 is accumulated.
+ return RegisterPair { reg, shuffle_epi32(reg, 3 * 4 + 2) };
+ }
+ INTGEMM_TARGET static inline Register OddUpcast(Register reg) { return reg; }
+};
+
+#ifdef INTGEMM_THIS_IS_AVX2
+struct Pack128Folder {
+ INTGEMM_TARGET static inline RegisterPair Even(Register first, Register second) {
+ return RegisterPair {
+ // This instruction generates 0s 1s 2s 3s 4f 5f 6f 7f
+ _mm256_permute2f128_si256(first, second, 0x21),
+ // This instruction generates 0f 1f 2f 3f 4s 5s 6s 7s
+ _mm256_blend_epi32(first, second, 0xf0)
+ };
+ }
+ INTGEMM_TARGET static inline SSE2::RegisterPair Odd(Register reg) {
+ return SSE2::RegisterPair { _mm256_extracti128_si256(reg, 1), _mm256_castsi256_si128(reg) };
+ }
+ INTGEMM_TARGET static inline Register OddUpcast(SSE2::Register reg) { return _mm256_castsi128_si256(reg); }
+};
+#endif
+
+#ifdef INTGEMM_THIS_IS_AVX512BW
+struct Pack128Folder {
+ INTGEMM_TARGET static inline RegisterPair Even(Register first, Register second) {
+ // TODO can this be optimized with a blend and a shuffle instruction?
+ return RegisterPair {
+ // Form [0th 128-bit of first, 0th 128-bit second, 2nd 128-bit of first, 2nd 128-bit of second]
+ _mm512_mask_permutex_epi64(first, 0xcc, second, (0 << 4) | (1 << 6)),
+ // Form [1st 128-bit of first, 1st 128-bit of second, 3rd 128-bit of first, 3rd 128-bit of second]
+ _mm512_mask_permutex_epi64(second, 0x33, first, 2 | (3 << 2))
+ };
+ }
+ INTGEMM_TARGET static inline AVX2::RegisterPair Odd(Register reg) {
+ return AVX2::RegisterPair { _mm512_castsi512_si256(reg), _mm512_extracti64x4_epi64(reg, 1) };
+ }
+ INTGEMM_TARGET static inline Register OddUpcast(AVX2::Register reg) { return _mm512_castsi256_si512(reg); }
+};
+
+struct Pack256Folder {
+ INTGEMM_TARGET static inline RegisterPair Even(Register first, Register second) {
+ return RegisterPair {
+ // This instruction generates first[2] first[3] second[0] second[1]
+ _mm512_shuffle_i64x2(first, second, 2 | (3 << 2) | (0 << 4) | (1 << 6)),
+ // This instruction generates first[0] first[1] second[2] second[3]
+ _mm512_mask_blend_epi64(0xf0, first, second)
+ };
+ }
+ INTGEMM_TARGET static inline AVX2::RegisterPair Odd(Register reg) {
+ return AVX2::RegisterPair { _mm512_castsi512_si256(reg), _mm512_extracti64x4_epi64(reg, 1) };
+ }
+ INTGEMM_TARGET static inline Register OddUpcast(AVX2::Register reg) { return _mm512_castsi256_si512(reg); }
+};
+
+template <class Op> struct PackFours {
+ // Collapse 4 AVX512 registers at once, interleaving 128-bit fields.
template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) {
const Index i = Iterator::template I<0>();
- Register hi = unpackhi_epi32(regs[2 * i], regs[2 * i + 1]);
- Register lo = unpacklo_epi32(regs[2 * i], regs[2 * i + 1]);
- regs[i] = Op::Run(hi, lo);
+ const Register *in = regs + i * 4;
+ // Do 256-bit interleaving first because it's slightly cheaper.
+ RegisterPair mix0pair = Pack256Folder::Even(in[0], in[2]);
+ RegisterPair mix1pair = Pack256Folder::Even(in[1], in[3]);
+ Register mix0 = Op::Run(mix0pair.hi, mix0pair.lo);
+ Register mix1 = Op::Run(mix1pair.hi, mix1pair.lo);
+ mix0pair = Pack128Folder::Even(mix0, mix1);
+ regs[i] = Op::Run(mix0pair.hi, mix0pair.lo);
}
};
+
+#endif
+
template <Index Valid, class Op> INTGEMM_TARGET static inline void Pack32(Register *regs) {
- StaticLoop<Pack32Even<Op>, MakeStaticLoopIterator<Valid / 2>>(regs);
- if (Valid & 1) {
- // For the odd case, shuffle to form 0 g 0 g where g is garbage and 0 is accumlated.
- Register shuffled = shuffle_epi32(regs[Valid - 1], 0x4C /* BADA */);
- regs[Valid / 2] = Op::Run(shuffled, regs[Valid - 1]);
- }
- // Now [0, (Valid + 1) / 2) contains registers to pack with 64-bit interleaving.
- Pack64<(Valid + 1) / 2, Op>(regs);
+ GenericPack<Valid, Op, Pack32Folder>(regs);
+ GenericPack<(Valid + 1) / 2, Op, Pack64Folder>(regs);
+ // SSE2 is done.
+#if defined(INTGEMM_THIS_IS_AVX2)
+ GenericPack<(Valid + 3) / 4, Op, Pack128Folder>(regs);
+#elif defined(INTGEMM_THIS_IS_AVX512BW)
+ StaticLoop<PackFours<Op>, MakeStaticLoopIterator<(Valid / 4)>>(regs);
+ // TODO: non-multiples of 4 registers.
+#endif
}
} // namespace INTGEMM_ARCH