diff options
author | Kenneth Heafield <kpu@users.noreply.github.com> | 2020-03-27 18:08:21 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-03-27 18:08:21 +0300 |
commit | 6779e6aac623422a66ddc02c4fb1ba73927fdad2 (patch) | |
tree | 416a65b3744b950b0ebbccf168070b55fbc10717 | |
parent | 0d0fe4147ed7540a65f8d7a4b7dac626a5632d7b (diff) | |
parent | dcf3a5ebc62849b823e38400873e2de81c8d2e9c (diff) |
Merge pull request #74 from kpu/unify-reference-mulitplies-funs
Unify reference mulitplies funs
-rw-r--r-- | test/add127_test.cc | 6 | ||||
-rw-r--r-- | test/multiply_test.cc | 4 | ||||
-rw-r--r-- | test/test.h | 38 | ||||
-rw-r--r-- | test/utils_test.cc | 7 | ||||
-rw-r--r-- | utils.h | 15 |
5 files changed, 40 insertions, 30 deletions
diff --git a/test/add127_test.cc b/test/add127_test.cc index d959b14..cec20c2 100644 --- a/test/add127_test.cc +++ b/test/add127_test.cc @@ -127,7 +127,7 @@ template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Ind }); AlignedVector<float> float_C(test_C.size()); - references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { + references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { return sum + bias[info.col_idx]; }); @@ -184,7 +184,7 @@ template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index widt Routine::Multiply(A_prep_old.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), slowint_C.begin())); AlignedVector<float> float_C(test_C.size()); - references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { + references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { return sum + bias[info.col_idx]; }); @@ -245,7 +245,7 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In // }); AlignedVector<float> float_C(test_C.size()); - references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { + references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { return sum + bias[info.col_idx]; }); /* diff --git a/test/multiply_test.cc b/test/multiply_test.cc index 260dd76..0fd8231 100644 --- a/test/multiply_test.cc +++ b/test/multiply_test.cc @@ -293,7 +293,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co }); AlignedVector<float> float_C(test_C.size()); - references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo&) { + references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo&) { return sum; }); @@ -346,7 +346,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index }); AlignedVector<float> float_C(test_C.size()); - references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { + references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { return sum + bias[info.col_idx]; }); diff --git a/test/test.h b/test/test.h index 7de38e9..f145681 100644 --- a/test/test.h +++ b/test/test.h @@ -76,30 +76,28 @@ void Quantize(const float* input, Type* output, float quant_mult, Index size) { } } -// Multiply A(float) x B(float) -template <typename LambdaCallback> -void MultiplyFF(const float* A, const float* B, float* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) { - for (Index r = 0; r < A_rows; ++r) { - for (Index c = 0; c < B_cols; ++c) { - float sum = 0.0f; - for (Index k = 0; k < width; ++k) { - sum += A[r * width + k] * B[k * B_cols + c]; - } - C[r * B_cols + c] = callback(sum, {r, c, A_rows, B_cols}); - } - } -} +/* + * Multiply C = A x B + * + * Notes: A and B has to be both integers or both floating points. + * + * Callback takes two arguments: + * - Intermediate value of multiplication 1 row times 1 column - it's int32_t or double based on types A and B. + * - Object containing information about position in output matrix - callbacks::OutputBufferInfo. + */ +template <typename TypeA, typename TypeB, typename TypeC, typename LambdaCallback, + typename std::enable_if< + (std::is_integral<TypeA>::value && std::is_integral<TypeB>::value) || + (std::is_floating_point<TypeA>::value && std::is_floating_point<TypeB>::value) + >::type* = nullptr> +void Multiply(const TypeA* A, const TypeB* B, TypeC* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) { + using IntermediateType = typename std::conditional<std::is_integral<TypeA>::value, int32_t, double>::type; -// Multiply A(int) x B(int) -template <typename TypeA, typename TypeB, typename LambdaCallback, - typename std::enable_if<std::is_integral<TypeA>::value>::type* = nullptr, - typename std::enable_if<std::is_integral<TypeB>::value>::type* = nullptr> -void Multiply(const TypeA* A, const TypeB* B, float* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) { for (Index r = 0; r < A_rows; ++r) { for (Index c = 0; c < B_cols; ++c) { - int32_t sum = 0; + IntermediateType sum = 0; for (Index k = 0; k < width; ++k) { - sum += int32_t(A[r * width + k]) * int32_t(B[k * B_cols + c]); + sum += IntermediateType(A[r * width + k]) * IntermediateType(B[k * B_cols + c]); } C[r * B_cols + c] = callback(sum, {r, c, A_rows, B_cols}); } diff --git a/test/utils_test.cc b/test/utils_test.cc index 8596104..9c2fb06 100644 --- a/test/utils_test.cc +++ b/test/utils_test.cc @@ -78,5 +78,12 @@ TEST_CASE("Static loop with mult-dim iterator (Iterator<5, 2>)",) { CHECK(result == 11223344); } +TEST_CASE("Round up",) { + CHECK(round_up(0, 5) == 0); + CHECK(round_up(1, 5) == 5); + CHECK(round_up(4, 5) == 5); + CHECK(round_up(6, 5) == 10); +} + } } @@ -52,20 +52,18 @@ constexpr subtuple_t<Tuple, Indices...> make_subtuple(const Tuple& tuple, sequen /* * Factorial */ -constexpr unsigned long long factorial(unsigned n) { +static constexpr unsigned long long factorial(unsigned n) { return n <= 1 ? 1 : n * factorial(n - 1); } /* * e^n, where n is integer */ -namespace { // anonymous namespace -constexpr double expi_nonnegative(unsigned n) { +static constexpr double expi_nonnegative(unsigned n) { return n == 0 ? 1.0 : (n == 1 ? 2.718281828459045 : expi_nonnegative(n / 2) * expi_nonnegative((n + 1) / 2)); } -} // anonymous namespace -constexpr double expi(int n) { +static constexpr double expi(int n) { return (n >= 0 ? expi_nonnegative(n) : 1.0 / expi_nonnegative(-n)); } @@ -201,4 +199,11 @@ __attribute__((always_inline)) static inline void StaticLoop(Args&&... args) { StaticLoop<Body, typename StaticLoopIterator::next>(std::forward<Args>(args)...); } +/* + * Round up + */ +static constexpr Index round_up(Index value, Index factor) { + return (value + factor - 1) / factor * factor; +} + } |