author | Nikolay Bogoychev <nheart@gmail.com> | 2020-01-24 22:03:01 +0300
---|---|---
committer | Nikolay Bogoychev <nheart@gmail.com> | 2020-01-24 22:03:01 +0300
commit | df77b4dc774fc148684f8790e08513c5aa458630 (patch) |
tree | bf9bf6326315df711f231089aedb77b129a40b6a |
parent | 87a3ae9a4bde6238d31a2b87fc22754f468828ff (diff) |
Rename function for easier templating
-rw-r--r-- | avx512vnni_gemm.h | 2
-rw-r--r-- | benchmarks/biasmultiply.cc | 2
-rw-r--r-- | intgemm.h | 8
-rw-r--r-- | multiply.h | 2
-rw-r--r-- | test/add127_test.cc | 10
5 files changed, 12 insertions, 12 deletions
diff --git a/avx512vnni_gemm.h b/avx512vnni_gemm.h
index bfcb282..3f616a6 100644
--- a/avx512vnni_gemm.h
+++ b/avx512vnni_gemm.h
@@ -112,7 +112,7 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
   }
 
   template <typename Callback>
-  INTGEMM_AVX512VNNI static void PrepareBiasFor8(const int8_t *B, Index width, Index B_cols, Callback callback) {
+  INTGEMM_AVX512VNNI static void PrepareBias(const int8_t *B, Index width, Index B_cols, Callback callback) {
     typedef __m512i Integer;
     assert(width % sizeof(Integer) == 0);
     assert(B_cols % 8 == 0);
diff --git a/benchmarks/biasmultiply.cc b/benchmarks/biasmultiply.cc
index ec8ca95..58d6ebf 100644
--- a/benchmarks/biasmultiply.cc
+++ b/benchmarks/biasmultiply.cc
@@ -40,7 +40,7 @@ std::chrono::duration<double> testNew(Index A_rows, Index width, Index B_cols) {
   AlignedVector<float> test_C(A_rows * B_cols);
 
   float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on
-  Routine::PrepareBiasFor8(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
+  Routine::PrepareBias(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
   auto start = std::chrono::system_clock::now();
   Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
   auto end = std::chrono::system_clock::now();
diff --git a/intgemm.h b/intgemm.h
--- a/intgemm.h
+++ b/intgemm.h
@@ -87,7 +87,7 @@ struct Unsupported_8bit {
     throw UnsupportedCPU();
   }
   template<class Callback>
-  static void PrepareBiasFor8(const int8_t *, Index, Index, Callback) {
+  static void PrepareBias(const int8_t *, Index, Index, Callback) {
     throw UnsupportedCPU();
   }
   static void SelectColumnsB(const int8_t *, int8_t *, Index, const Index *, const Index *) {
@@ -239,7 +239,7 @@ struct Int8Mult {
   // Multiply C = A * B, presuming A and B have been prepared.
   static void (*Multiply)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback);
   static void (*Multiply8Shift)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback);
-  static void (*PrepareBiasFor8)(const int8_t *B, Index width, Index B_cols, Callback callback);
+  static void (*PrepareBias)(const int8_t *B, Index width, Index B_cols, Callback callback);
 };
 
 template <typename Callback>
@@ -249,7 +249,7 @@ template <class Callback>
 void (*Int8Mult<Callback>::Multiply8Shift)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::Multiply8Shift<Callback>, AVX512_8bit::Multiply8Shift<Callback>, AVX2_8bit::Multiply8Shift<Callback>, SSSE3_8bit::Multiply8Shift<Callback>, SSSE3_8bit::Multiply8Shift<Callback>, Unsupported_8bit::Multiply8Shift);
 
 template <class Callback>
-void (*Int8Mult<Callback>::PrepareBiasFor8)(const int8_t *B, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::PrepareBiasFor8<Callback>, AVX512_8bit::PrepareBiasFor8<Callback>, AVX2_8bit::PrepareBiasFor8<Callback>, SSSE3_8bit::PrepareBiasFor8<Callback>, SSSE3_8bit::PrepareBiasFor8<Callback>, Unsupported_8bit::PrepareBiasFor8);
+void (*Int8Mult<Callback>::PrepareBias)(const int8_t *B, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::PrepareBias<Callback>, AVX512_8bit::PrepareBias<Callback>, AVX2_8bit::PrepareBias<Callback>, SSSE3_8bit::PrepareBias<Callback>, SSSE3_8bit::PrepareBias<Callback>, Unsupported_8bit::PrepareBias);
 
 struct Int8 {
   using Integer = int8_t;
@@ -333,7 +333,7 @@ struct Int8Shift {
   // unquant_mult is computed by (-1)*(alpha)*(alpha)/(127.0f);
   template<class Callback>
   static void PrepareBias(const int8_t *B, Index width, Index B_cols, Callback callback) {
-    Int8Mult<Callback>::PrepareBiasFor8(B, width, B_cols, callback);
+    Int8Mult<Callback>::PrepareBias(B, width, B_cols, callback);
   }
 
   static const char *const kName;
diff --git a/multiply.h b/multiply.h
--- a/multiply.h
+++ b/multiply.h
@@ -197,7 +197,7 @@ template <typename Callback> target static void Multiply(const int16_t *A, const
 
 //An int8_prepbias version of the above code, using the add 127 technique
 #define INTGEMM_PREPAREBIASFOR8(Integer, target, cpu_type) \
-template <class Callback> target static void PrepareBiasFor8(const int8_t *B, Index width, Index B_cols, Callback callback) { \
+template <class Callback> target static void PrepareBias(const int8_t *B, Index width, Index B_cols, Callback callback) { \
   assert(width % (sizeof(Integer) / sizeof(int8_t)) == 0); \
   assert(B_cols % 8 == 0); \
   assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); \
diff --git a/test/add127_test.cc b/test/add127_test.cc
index 86d630b..d1b850d 100644
--- a/test/add127_test.cc
+++ b/test/add127_test.cc
@@ -71,7 +71,7 @@ template <class Routine> void TestPrepareBias(Index rows, Index cols) {
 
   float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f);
 
-  Routine::PrepareBiasFor8(B_prep.begin(), rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, inputBias.begin(), inputBias.begin()));
+  Routine::PrepareBias(B_prep.begin(), rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, inputBias.begin(), inputBias.begin()));
 
   int A_rows = 1;
   AlignedVector<int8_t> A_prep2(A_rows*rows);
@@ -136,8 +136,8 @@ template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Ind
   *
   */
   float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on
-  Routine::PrepareBiasFor8(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
-  //Routine::PrepareBiasFor8(B.begin(), bias.begin(), alpha, width, B_cols);
+  Routine::PrepareBias(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
+  //Routine::PrepareBias(B.begin(), bias.begin(), alpha, width, B_cols);
   Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
 
   Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
@@ -190,7 +190,7 @@ template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index widt
   * Multiply8 shift multiplication
   */
   float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on
-  Routine::PrepareBiasFor8(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
+  Routine::PrepareBias(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
   Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
 
   Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
@@ -256,7 +256,7 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In
 
   //Now prepare Fast integer Bias
-  Routine::PrepareBiasFor8(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
+  Routine::PrepareBias(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
   Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
 
   // Reference INT VERSION HERE with ADD127
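The motivation for the rename is visible in the test and benchmark code above: callers are templated on a `Routine` backend, and after this change every backend exposes the same member name `PrepareBias`, so one template instantiates against any of them. Below is a minimal, self-contained sketch of that pattern; `BackendA`, `BackendB`, and `RunPrepare` are hypothetical stand-ins for illustration, not intgemm types.

```cpp
#include <cstdint>
#include <cstdio>

// Toy stand-ins for two backends. In intgemm these would be types such as
// AVX512_8bit or AVX2_8bit; the bodies here are placeholders.
struct BackendA {
  static void PrepareBias(const int8_t *, int width, int cols) {
    std::printf("BackendA::PrepareBias(width=%d, cols=%d)\n", width, cols);
  }
};

struct BackendB {
  static void PrepareBias(const int8_t *, int width, int cols) {
    std::printf("BackendB::PrepareBias(width=%d, cols=%d)\n", width, cols);
  }
};

// A single template serves every backend because they all expose the same
// static member name. With a per-type name like PrepareBiasFor8 on some
// backends and PrepareBias on others, this template could not be written
// uniformly.
template <class Routine> void RunPrepare(const int8_t *B, int width, int cols) {
  Routine::PrepareBias(B, width, cols);
}

int main() {
  int8_t dummy[1] = {0};
  RunPrepare<BackendA>(dummy, 64, 8);
  RunPrepare<BackendB>(dummy, 64, 8);
}
```

The same idea underlies the `Int8Mult<Callback>` function pointers in intgemm.h: because each backend provides an identically named static member, the `ChooseCPU` dispatch shown in the diff can select the widest supported instruction set and store it behind a single pointer.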