diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-02-07 19:07:28 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-03-27 17:08:05 +0300 |
commit | 810502d28d0e3177989f3296ca69c8bc2b01ac26 (patch) | |
tree | ca228a838d0778f165319b371648f3f7f164d529 | |
parent | 0d0fe4147ed7540a65f8d7a4b7dac626a5632d7b (diff) |
Unify references::MultiplyFF and references::Multiply
-rw-r--r-- | test/add127_test.cc | 6 | ||||
-rw-r--r-- | test/multiply_test.cc | 4 | ||||
-rw-r--r-- | test/test.h | 38 |
3 files changed, 23 insertions, 25 deletions
diff --git a/test/add127_test.cc b/test/add127_test.cc index d959b14..cec20c2 100644 --- a/test/add127_test.cc +++ b/test/add127_test.cc @@ -127,7 +127,7 @@ template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Ind }); AlignedVector<float> float_C(test_C.size()); - references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { + references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { return sum + bias[info.col_idx]; }); @@ -184,7 +184,7 @@ template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index widt Routine::Multiply(A_prep_old.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), slowint_C.begin())); AlignedVector<float> float_C(test_C.size()); - references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { + references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { return sum + bias[info.col_idx]; }); @@ -245,7 +245,7 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In // }); AlignedVector<float> float_C(test_C.size()); - references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { + references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { return sum + bias[info.col_idx]; }); /* diff --git a/test/multiply_test.cc b/test/multiply_test.cc index 260dd76..0fd8231 100644 --- a/test/multiply_test.cc +++ b/test/multiply_test.cc @@ -293,7 +293,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co }); AlignedVector<float> float_C(test_C.size()); - references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo&) { + references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo&) { return sum; }); @@ -346,7 +346,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index }); AlignedVector<float> float_C(test_C.size()); - references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { + references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { return sum + bias[info.col_idx]; }); diff --git a/test/test.h b/test/test.h index 7de38e9..f145681 100644 --- a/test/test.h +++ b/test/test.h @@ -76,30 +76,28 @@ void Quantize(const float* input, Type* output, float quant_mult, Index size) { } } -// Multiply A(float) x B(float) -template <typename LambdaCallback> -void MultiplyFF(const float* A, const float* B, float* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) { - for (Index r = 0; r < A_rows; ++r) { - for (Index c = 0; c < B_cols; ++c) { - float sum = 0.0f; - for (Index k = 0; k < width; ++k) { - sum += A[r * width + k] * B[k * B_cols + c]; - } - C[r * B_cols + c] = callback(sum, {r, c, A_rows, B_cols}); - } - } -} +/* + * Multiply C = A x B + * + * Notes: A and B has to be both integers or both floating points. + * + * Callback takes two arguments: + * - Intermediate value of multiplication 1 row times 1 column - it's int32_t or double based on types A and B. + * - Object containing information about position in output matrix - callbacks::OutputBufferInfo. + */ +template <typename TypeA, typename TypeB, typename TypeC, typename LambdaCallback, + typename std::enable_if< + (std::is_integral<TypeA>::value && std::is_integral<TypeB>::value) || + (std::is_floating_point<TypeA>::value && std::is_floating_point<TypeB>::value) + >::type* = nullptr> +void Multiply(const TypeA* A, const TypeB* B, TypeC* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) { + using IntermediateType = typename std::conditional<std::is_integral<TypeA>::value, int32_t, double>::type; -// Multiply A(int) x B(int) -template <typename TypeA, typename TypeB, typename LambdaCallback, - typename std::enable_if<std::is_integral<TypeA>::value>::type* = nullptr, - typename std::enable_if<std::is_integral<TypeB>::value>::type* = nullptr> -void Multiply(const TypeA* A, const TypeB* B, float* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) { for (Index r = 0; r < A_rows; ++r) { for (Index c = 0; c < B_cols; ++c) { - int32_t sum = 0; + IntermediateType sum = 0; for (Index k = 0; k < width; ++k) { - sum += int32_t(A[r * width + k]) * int32_t(B[k * B_cols + c]); + sum += IntermediateType(A[r * width + k]) * IntermediateType(B[k * B_cols + c]); } C[r * B_cols + c] = callback(sum, {r, c, A_rows, B_cols}); } |