Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Chudyk <mateuszchudyk@gmail.com>2020-02-07 19:07:28 +0300
committerMateusz Chudyk <mateuszchudyk@gmail.com>2020-03-27 17:08:05 +0300
commit810502d28d0e3177989f3296ca69c8bc2b01ac26 (patch)
treeca228a838d0778f165319b371648f3f7f164d529
parent0d0fe4147ed7540a65f8d7a4b7dac626a5632d7b (diff)
Unify references::MultiplyFF and references::Multiply
-rw-r--r--test/add127_test.cc6
-rw-r--r--test/multiply_test.cc4
-rw-r--r--test/test.h38
3 files changed, 23 insertions, 25 deletions
diff --git a/test/add127_test.cc b/test/add127_test.cc
index d959b14..cec20c2 100644
--- a/test/add127_test.cc
+++ b/test/add127_test.cc
@@ -127,7 +127,7 @@ template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Ind
});
AlignedVector<float> float_C(test_C.size());
- references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
return sum + bias[info.col_idx];
});
@@ -184,7 +184,7 @@ template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index widt
Routine::Multiply(A_prep_old.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), slowint_C.begin()));
AlignedVector<float> float_C(test_C.size());
- references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
return sum + bias[info.col_idx];
});
@@ -245,7 +245,7 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In
// });
AlignedVector<float> float_C(test_C.size());
- references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
return sum + bias[info.col_idx];
});
/*
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 260dd76..0fd8231 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -293,7 +293,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
});
AlignedVector<float> float_C(test_C.size());
- references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo&) {
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo&) {
return sum;
});
@@ -346,7 +346,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
});
AlignedVector<float> float_C(test_C.size());
- references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
return sum + bias[info.col_idx];
});
diff --git a/test/test.h b/test/test.h
index 7de38e9..f145681 100644
--- a/test/test.h
+++ b/test/test.h
@@ -76,30 +76,28 @@ void Quantize(const float* input, Type* output, float quant_mult, Index size) {
}
}
-// Multiply A(float) x B(float)
-template <typename LambdaCallback>
-void MultiplyFF(const float* A, const float* B, float* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) {
- for (Index r = 0; r < A_rows; ++r) {
- for (Index c = 0; c < B_cols; ++c) {
- float sum = 0.0f;
- for (Index k = 0; k < width; ++k) {
- sum += A[r * width + k] * B[k * B_cols + c];
- }
- C[r * B_cols + c] = callback(sum, {r, c, A_rows, B_cols});
- }
- }
-}
+/*
+ * Multiply C = A x B
+ *
+ * Notes: A and B has to be both integers or both floating points.
+ *
+ * Callback takes two arguments:
+ * - Intermediate value of multiplication 1 row times 1 column - it's int32_t or double based on types A and B.
+ * - Object containing information about position in output matrix - callbacks::OutputBufferInfo.
+ */
+template <typename TypeA, typename TypeB, typename TypeC, typename LambdaCallback,
+ typename std::enable_if<
+ (std::is_integral<TypeA>::value && std::is_integral<TypeB>::value) ||
+ (std::is_floating_point<TypeA>::value && std::is_floating_point<TypeB>::value)
+ >::type* = nullptr>
+void Multiply(const TypeA* A, const TypeB* B, TypeC* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) {
+ using IntermediateType = typename std::conditional<std::is_integral<TypeA>::value, int32_t, double>::type;
-// Multiply A(int) x B(int)
-template <typename TypeA, typename TypeB, typename LambdaCallback,
- typename std::enable_if<std::is_integral<TypeA>::value>::type* = nullptr,
- typename std::enable_if<std::is_integral<TypeB>::value>::type* = nullptr>
-void Multiply(const TypeA* A, const TypeB* B, float* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) {
for (Index r = 0; r < A_rows; ++r) {
for (Index c = 0; c < B_cols; ++c) {
- int32_t sum = 0;
+ IntermediateType sum = 0;
for (Index k = 0; k < width; ++k) {
- sum += int32_t(A[r * width + k]) * int32_t(B[k * B_cols + c]);
+ sum += IntermediateType(A[r * width + k]) * IntermediateType(B[k * B_cols + c]);
}
C[r * B_cols + c] = callback(sum, {r, c, A_rows, B_cols});
}