Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKenneth Heafield <kpu@users.noreply.github.com>2020-03-27 18:08:21 +0300
committerGitHub <noreply@github.com>2020-03-27 18:08:21 +0300
commit6779e6aac623422a66ddc02c4fb1ba73927fdad2 (patch)
tree416a65b3744b950b0ebbccf168070b55fbc10717
parent0d0fe4147ed7540a65f8d7a4b7dac626a5632d7b (diff)
parentdcf3a5ebc62849b823e38400873e2de81c8d2e9c (diff)
Merge pull request #74 from kpu/unify-reference-mulitplies-funs
Unify reference mulitplies funs
-rw-r--r--test/add127_test.cc6
-rw-r--r--test/multiply_test.cc4
-rw-r--r--test/test.h38
-rw-r--r--test/utils_test.cc7
-rw-r--r--utils.h15
5 files changed, 40 insertions, 30 deletions
diff --git a/test/add127_test.cc b/test/add127_test.cc
index d959b14..cec20c2 100644
--- a/test/add127_test.cc
+++ b/test/add127_test.cc
@@ -127,7 +127,7 @@ template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Ind
});
AlignedVector<float> float_C(test_C.size());
- references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
return sum + bias[info.col_idx];
});
@@ -184,7 +184,7 @@ template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index widt
Routine::Multiply(A_prep_old.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), slowint_C.begin()));
AlignedVector<float> float_C(test_C.size());
- references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
return sum + bias[info.col_idx];
});
@@ -245,7 +245,7 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In
// });
AlignedVector<float> float_C(test_C.size());
- references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
return sum + bias[info.col_idx];
});
/*
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 260dd76..0fd8231 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -293,7 +293,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
});
AlignedVector<float> float_C(test_C.size());
- references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo&) {
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo&) {
return sum;
});
@@ -346,7 +346,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
});
AlignedVector<float> float_C(test_C.size());
- references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
return sum + bias[info.col_idx];
});
diff --git a/test/test.h b/test/test.h
index 7de38e9..f145681 100644
--- a/test/test.h
+++ b/test/test.h
@@ -76,30 +76,28 @@ void Quantize(const float* input, Type* output, float quant_mult, Index size) {
}
}
-// Multiply A(float) x B(float)
-template <typename LambdaCallback>
-void MultiplyFF(const float* A, const float* B, float* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) {
- for (Index r = 0; r < A_rows; ++r) {
- for (Index c = 0; c < B_cols; ++c) {
- float sum = 0.0f;
- for (Index k = 0; k < width; ++k) {
- sum += A[r * width + k] * B[k * B_cols + c];
- }
- C[r * B_cols + c] = callback(sum, {r, c, A_rows, B_cols});
- }
- }
-}
+/*
+ * Multiply C = A x B
+ *
+ * Notes: A and B has to be both integers or both floating points.
+ *
+ * Callback takes two arguments:
+ * - Intermediate value of multiplication 1 row times 1 column - it's int32_t or double based on types A and B.
+ * - Object containing information about position in output matrix - callbacks::OutputBufferInfo.
+ */
+template <typename TypeA, typename TypeB, typename TypeC, typename LambdaCallback,
+ typename std::enable_if<
+ (std::is_integral<TypeA>::value && std::is_integral<TypeB>::value) ||
+ (std::is_floating_point<TypeA>::value && std::is_floating_point<TypeB>::value)
+ >::type* = nullptr>
+void Multiply(const TypeA* A, const TypeB* B, TypeC* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) {
+ using IntermediateType = typename std::conditional<std::is_integral<TypeA>::value, int32_t, double>::type;
-// Multiply A(int) x B(int)
-template <typename TypeA, typename TypeB, typename LambdaCallback,
- typename std::enable_if<std::is_integral<TypeA>::value>::type* = nullptr,
- typename std::enable_if<std::is_integral<TypeB>::value>::type* = nullptr>
-void Multiply(const TypeA* A, const TypeB* B, float* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) {
for (Index r = 0; r < A_rows; ++r) {
for (Index c = 0; c < B_cols; ++c) {
- int32_t sum = 0;
+ IntermediateType sum = 0;
for (Index k = 0; k < width; ++k) {
- sum += int32_t(A[r * width + k]) * int32_t(B[k * B_cols + c]);
+ sum += IntermediateType(A[r * width + k]) * IntermediateType(B[k * B_cols + c]);
}
C[r * B_cols + c] = callback(sum, {r, c, A_rows, B_cols});
}
diff --git a/test/utils_test.cc b/test/utils_test.cc
index 8596104..9c2fb06 100644
--- a/test/utils_test.cc
+++ b/test/utils_test.cc
@@ -78,5 +78,12 @@ TEST_CASE("Static loop with mult-dim iterator (Iterator<5, 2>)",) {
CHECK(result == 11223344);
}
+TEST_CASE("Round up",) {
+ CHECK(round_up(0, 5) == 0);
+ CHECK(round_up(1, 5) == 5);
+ CHECK(round_up(4, 5) == 5);
+ CHECK(round_up(6, 5) == 10);
+}
+
}
}
diff --git a/utils.h b/utils.h
index 7fa2f6e..94b16d3 100644
--- a/utils.h
+++ b/utils.h
@@ -52,20 +52,18 @@ constexpr subtuple_t<Tuple, Indices...> make_subtuple(const Tuple& tuple, sequen
/*
* Factorial
*/
-constexpr unsigned long long factorial(unsigned n) {
+static constexpr unsigned long long factorial(unsigned n) {
return n <= 1 ? 1 : n * factorial(n - 1);
}
/*
* e^n, where n is integer
*/
-namespace { // anonymous namespace
-constexpr double expi_nonnegative(unsigned n) {
+static constexpr double expi_nonnegative(unsigned n) {
return n == 0 ? 1.0 : (n == 1 ? 2.718281828459045 : expi_nonnegative(n / 2) * expi_nonnegative((n + 1) / 2));
}
-} // anonymous namespace
-constexpr double expi(int n) {
+static constexpr double expi(int n) {
return (n >= 0 ? expi_nonnegative(n) : 1.0 / expi_nonnegative(-n));
}
@@ -201,4 +199,11 @@ __attribute__((always_inline)) static inline void StaticLoop(Args&&... args) {
StaticLoop<Body, typename StaticLoopIterator::next>(std::forward<Args>(args)...);
}
+/*
+ * Round up
+ */
+static constexpr Index round_up(Index value, Index factor) {
+ return (value + factor - 1) / factor * factor;
+}
+
}