Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2020-04-23 02:28:01 +0300
committerKenneth Heafield <github@kheafield.com>2020-04-23 02:28:01 +0300
commite12a54041950569742bac71167f8d8718e5932c1 (patch)
treedfb7f9aa9c4973143f2ed550bf482a748d9dc925
parent47d54971d625711d6292da251ebd904e98ebd43f (diff)
Insane implementation of most cases for writing C. Still missing offset scatter.
-rw-r--r--test/tile_test.inl6
-rw-r--r--tile/access.h73
2 files changed, 73 insertions, 6 deletions
diff --git a/test/tile_test.inl b/test/tile_test.inl
index cb7b155..2a6c059 100644
--- a/test/tile_test.inl
+++ b/test/tile_test.inl
@@ -269,10 +269,14 @@ TEMPLATE_TEST_CASE("MultiplyNoOverhang Unrolled Signed8 " INTGEMM_TEST_NAME, "[t
(UnrollKernel<1, 1, 31, Signed8>),
(UnrollKernel<1, 1, 32, Signed8>),
(UnrollKernel<2, 1, 1, Signed8>),
+ (UnrollKernel<2, 1, 2, Signed8>),
(UnrollKernel<3, 1, 1, Signed8>),
+ (UnrollKernel<3, 1, 3, Signed8>),
(UnrollKernel<4, 1, 1, Signed8>),
(UnrollKernel<5, 1, 1, Signed8>),
- (UnrollKernel<17, 1, 1, Signed8>)
+ (UnrollKernel<15, 1, 1, Signed8>),
+ (UnrollKernel<16, 1, 1, Signed8>)
+// (UnrollKernel<17, 1, 1, Signed8>)
) {
if (kCPU < CPUType::INTGEMM_ARCH) return;
TestMultiplyNoOverhangShapes<TestType>();
diff --git a/tile/access.h b/tile/access.h
index d5e6119..0279392 100644
--- a/tile/access.h
+++ b/tile/access.h
@@ -1,5 +1,7 @@
#pragma once
+#include <type_traits>
+
#include "../types.h"
namespace intgemm {
@@ -21,11 +23,73 @@ template <class T> class RowMajorAccess {
Content &Front() { return *data_; }
// TODO: SLOW. This is here for testing.
- template <Index A_rows, Index B_cols> void Write(const __m128i *from) { Write<A_rows, B_cols>(reinterpret_cast<const T*>(from)); }
- template <Index A_rows, Index B_cols> void Write(const __m256i *from) { Write<A_rows, B_cols>(reinterpret_cast<const T*>(from)); }
- template <Index A_rows, Index B_cols> void Write(const __m512i *from) { Write<A_rows, B_cols>(reinterpret_cast<const T*>(from)); }
+ template <Index A_rows, Index B_cols> void Write(const __m128i *from) { SlowWrite<A_rows, B_cols>(reinterpret_cast<const T*>(from)); }
+ template <Index A_rows, Index B_cols> void Write(const __m256i *from) { SlowWrite<A_rows, B_cols>(reinterpret_cast<const T*>(from)); }
+ template <Index A_rows, Index B_cols> void Write(const __m512i *from) {
+ WriteImpl<A_rows, B_cols, B_cols>(from);
+ }
+
+ private:
+ // If there's a full register to write for a column, do that.
+ template <Index A_rows, Index B_cols, Index CurColumn> INTGEMM_AVX512BW
+ typename std::enable_if<A_rows && B_cols && (CurColumn >= 16)>::type
+ WriteImpl(const __m512i *from) {
+ _mm512_storeu_si512(data_, *from);
+ Add(0, 16).template WriteImpl<A_rows, B_cols, (CurColumn - 16)>(from + 1);
+ }
+
+ // There is a mix of rows in a register and we need a scatter. Also, we are lucky enough to be at the beginning of a row.
+ // TODO case where we are not at the beginning of a register with _mm512_alignr_epi32.
+ template <Index A_rows, Index B_cols, Index CurColumn> INTGEMM_AVX512BW
+ typename std::enable_if<(A_rows > 1) && CurColumn && (CurColumn < 16) && (CurColumn == B_cols) /* beginning of row */>::type
+ WriteImpl(const __m512i *from) {
+ // TODO: test this more, also save it somewhere! Make sure compiler isn't recreating this every time.
+ const __m512i coefficients_lo = _mm512_set_epi32(
+ 15 / B_cols, 14 / B_cols, 13 / B_cols, 12 / B_cols,
+ 11 / B_cols, 10 / B_cols, 9 / B_cols, 8 / B_cols,
+ 7 / B_cols, 6 / B_cols, 5 / B_cols, 4 / B_cols,
+ 3 / B_cols, 2 / B_cols, 1 / B_cols, 0);
+ __m512i cols_reg = _mm512_set1_epi32(cols_);
+ // Multiply by the number of columns for the offsets.
+ const __m512i multiplied = _mm512_mullo_epi32(cols_reg, coefficients_lo);
+ // Add row offsets.
+ const __m512i row_offsets = _mm512_set_epi32(
+ 15 % B_cols, 14 % B_cols, 13 % B_cols, 12 % B_cols,
+ 11 % B_cols, 10 % B_cols, 9 % B_cols, 8 % B_cols,
+ 7 % B_cols, 6 % B_cols, 5 % B_cols, 4 % B_cols,
+ 3 % B_cols, 2 % B_cols, 1 % B_cols, 0);
+ // These are the offsets to use if we're perfectly aligned: B_cols is a divisor or multiple of 16.
+ __m512i offsets = _mm512_add_epi32(row_offsets, multiplied);
+ // We might be at the end of the data, in which case a mask is needed.
+ constexpr Index remaining = (A_rows - 1) * B_cols + CurColumn;
+ _mm512_mask_i32scatter_epi32(data_, (1 << remaining) - 1, offsets, *from, sizeof(int32_t));
+ // We just wrote 16 values: CurColumn, the next row (all or partial), possibly the next etc.
+ // 16 - CurColumn of the next row and whatever followed.
+ constexpr Index WroteMore = ((remaining < 16) ? remaining : 16) - CurColumn;
+ // TODO: testing on this.
+ Add(WroteMore / B_cols, WroteMore % B_cols - CurColumn).template WriteImpl<A_rows - WroteMore / B_cols, B_cols, WroteMore % B_cols>(from);
+ }
- template <Index A_rows, Index B_cols> void Write(const T *from) {
+ // At clean end of column, move to next row.
+ template <Index A_rows, Index B_cols, Index CurColumn> INTGEMM_AVX512BW
+ typename std::enable_if<A_rows && B_cols && (CurColumn == 0)>::type
+ WriteImpl(const __m512i *from) {
+ Add(1, -B_cols).template WriteImpl<A_rows - 1, B_cols, B_cols>(from);
+ }
+
+ // On the last row, finish the last write with a mask.
+ template <Index A_rows, Index B_cols, Index CurColumn> INTGEMM_AVX512BW
+ typename std::enable_if<(A_rows == 1) && B_cols && (CurColumn < 16 && CurColumn > 0)>::type
+ WriteImpl(const __m512i *from) {
+ _mm512_mask_storeu_epi32(data_, (1 << CurColumn) - 1, *from);
+ }
+
+ // Nothing to write.
+ template <Index A_rows, Index B_cols, Index CurColumn> INTGEMM_AVX512BW
+ typename std::enable_if<!A_rows || !B_cols>::type
+ WriteImpl(const __m512i *) {}
+
+ template <Index A_rows, Index B_cols> void SlowWrite(const T *from) {
for (Index i = 0; i < A_rows; ++i) {
for (Index j = 0; j < B_cols; ++j) {
data_[i * cols_ + j] = from[i * B_cols + j];
@@ -33,7 +97,6 @@ template <class T> class RowMajorAccess {
}
}
- private:
Content *data_;
Index cols_;
};