Welcome to mirror list, hosted at ThFree Co, Russian Federation.

prepare_b_transposed.cc « test - github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
blob: 5969724844d0af4bc78527998c3fe8186e90677c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include "test.h"
#include "../aligned.h"
#include "../avx2_gemm.h"
#include "../avx512_gemm.h"
#include "../sse2_gemm.h"
#include "../ssse3_gemm.h"

#include <algorithm>
#include <cstring>
#include <iostream>
#include <math.h>

namespace intgemm {
namespace {

template <typename Backend>
void PrepareBTransposedRef(const float* input, typename Backend::Integer* output, float quant_mult, Index B_transposed_cols, Index B_transposed_rows) {
  using vec_t = intgemm::vector_t<Backend::kUses, typename Backend::Integer>;
  constexpr Index vec_len = sizeof(vec_t) / sizeof(typename Backend::Integer);

  for (Index i = 0; i < B_transposed_rows * B_transposed_cols / 8; i += vec_len)
    for (Index j = 0; j < 8; ++j)
      for (Index k = 0; k < vec_len; ++k) {
        Index col = (i + k) % B_transposed_cols;
        Index row = 8 * ((i + k) / B_transposed_cols) + j;
        *output++ = static_cast<typename Backend::Integer>(input[row * B_transposed_cols + col] * quant_mult);
      }
}

template <typename Backend>
bool Test(const AlignedVector<float>& input, Index B_rows, Index B_cols, float quant_mult) {
  bool success = true;

  AlignedVector<typename Backend::Integer> output(input.size());
  Backend::PrepareBTransposed(input.begin(), output.begin(), quant_mult, B_rows, B_cols);

  AlignedVector<typename Backend::Integer> reference(input.size());
  PrepareBTransposedRef<Backend>(input.begin(), reference.begin(), quant_mult, B_rows, B_cols);

  for (std::size_t i = 0; i < output.size(); ++i) {
    if (output[i] != reference[i]) {
      UNSCOPED_INFO("Error at " << i << ", output = " << int(output[i]) << ", reference = " << int(reference[i]));
      success = false;
      break;
    }
  }
  return success;
}

// Builds a B_rows * B_cols float input and runs Test<Backend> on it.
//
// The input is filled with the repeating sequence 0, 1, ..., divider - 1,
// where divider is the number of Backend::Integer values in one
// vector_t<Backend::kUses, Backend::Integer>.
//
// Returns the result of Test<Backend> for that input.
template <typename Backend>
bool TestMany(Index B_rows, Index B_cols, float quant_mult) {
  AlignedVector<float> input(B_rows * B_cols);

  // The counter lives in this call's frame instead of in a static inside the
  // lambda, so every call starts the pattern at 0 regardless of how often
  // TestMany<Backend> has already run for the same Backend.
  int value = 0;
  std::generate(input.begin(), input.end(), [&value]() {
    static constexpr int divider = sizeof(intgemm::vector_t<Backend::kUses, typename Backend::Integer>) / sizeof(typename Backend::Integer);
    return static_cast<float>((value++) % divider);
  });

  return Test<Backend>(input, B_rows, B_cols, quant_mult);
}

TEST_CASE("PrepareBTransposed SSE2", "") {
  // Nothing is checked when kCPU ranks below SSE2.
  const bool cpu_too_old = kCPU < CPUType::SSE2;
  if (!cpu_too_old) {
    CHECK(TestMany<SSE2_16bit>(4, 128, 2.0f));
  }
}

TEST_CASE("PrepareBTransposed SSSE3", "") {
  // Nothing is checked when kCPU ranks below SSSE3.
  const bool cpu_too_old = kCPU < CPUType::SSSE3;
  if (!cpu_too_old) {
    CHECK(TestMany<SSSE3_8bit>(4, 128, 2.0f));
  }
}

TEST_CASE("PrepareBTransposed AVX2", "") {
  // Nothing is checked when kCPU ranks below AVX2.
  const bool cpu_too_old = kCPU < CPUType::AVX2;
  if (!cpu_too_old) {
    // Both integer widths are exercised, 8-bit first.
    CHECK(TestMany<AVX2_8bit>(8, 128, 2.0f));
    CHECK(TestMany<AVX2_16bit>(8, 128, 2.0f));
  }
}

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
  // Only compiled when INTGEMM_COMPILER_SUPPORTS_AVX512BW is defined.
  TEST_CASE("PrepareBTransposed AVX512", "") {
    // Nothing is checked when kCPU ranks below AVX512BW.
    const bool cpu_too_old = kCPU < CPUType::AVX512BW;
    if (!cpu_too_old) {
      // Both integer widths are exercised, 8-bit first.
      CHECK(TestMany<AVX512_8bit>(16, 128, 2.0f));
      CHECK(TestMany<AVX512_16bit>(16, 128, 2.0f));
    }
  }
#endif

}
}