Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'test/test.h')
-rw-r--r-- test/test.h 89
1 files changed, 80 insertions, 9 deletions
diff --git a/test/test.h b/test/test.h
index 291ff45..7c294f8 100644
--- a/test/test.h
+++ b/test/test.h
@@ -1,11 +1,15 @@
#pragma once
+#include "intgemm_config.h"
+
#include "../3rd_party/catch.hpp"
-#include <sstream>
#include "../intgemm.h"
#include "../aligned.h"
-#include "intgemm_config.h"
+#include <math.h>
+#include <sstream>
+#include <iostream>
+#include <iomanip>
#define CHECK_MESSAGE(cond, msg) do { INFO(msg); CHECK(cond); } while(0)
#define CHECK_FALSE_MESSAGE(cond, msg) do { INFO(msg); CHECK_FALSE(cond); } while(0)
@@ -21,13 +25,80 @@
#define KERNEL_TEST_CASE(name) TEST_CASE("Kernel: " name, "[kernel_test]")
namespace intgemm {
-void SlowRefFloat(const float *A, const float *B, float *C, Index A_rows, Index width, Index B_cols, const float *bias=nullptr);
-// Compute A*B slowly from integers.
-template <class Integer> void SlowRefInt(const Integer *A, const Integer *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias=nullptr);
-void SlowRefInt(const uint8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias=nullptr);
+void Compare(const float *float_ref, const float *int_ref, const float *int_test,
+ std::size_t size, std::string test_info, float int_tolerance,
+ float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance);
+
+/*
+ * References
+ */
+namespace references {
+
+// Quantize
+// Reference quantization: for every i in [0, size),
+//   output[i] = saturate<Type>(roundf(input[i] * quant_mult)).
+//   input      - `size` floats to quantize.
+//   output     - receives `size` quantized values.
+//   quant_mult - scale applied to each input before rounding.
+//   size       - number of elements to process.
+// The product is formed and rounded in single precision; roundf rounds
+// halfway cases away from zero. A NaN product is stored as
+// std::numeric_limits<Type>::min() because of the argument order passed to
+// std::max below.
+// Intended for integer Type of at most 32 bits, whose bounds are exactly
+// representable in double.
+template <typename Type>
+void Quantize(const float* input, Type* output, float quant_mult, Index size) {
+  // Clamp in double rather than float: float cannot represent e.g. INT32_MAX
+  // exactly (it rounds up to 2^31), so a value clamped in float could still
+  // be out of range and its conversion to Type would be undefined behaviour.
+  const double lowest = static_cast<double>(std::numeric_limits<Type>::min());
+  const double highest = static_cast<double>(std::numeric_limits<Type>::max());
+  for (Index i = 0; i < size; ++i) {
+    double value = roundf(input[i] * quant_mult);
+    value = std::max<double>(lowest, value);
+    value = std::min<double>(highest, value);
+    output[i] = static_cast<Type>(value);
+  }
+}
+
+// Multiply A(float) x B(float)
+// Reference float matrix product. A is indexed as A[row * width + k] and B as
+// B[k * B_cols + col], i.e. both are read as row-major with A being
+// A_rows x width and B being width x B_cols.
+// For every output cell the dot product is accumulated in a float, in
+// increasing k order, and then handed to `callback` together with a
+// brace-initialised {row, col, A_rows, B_cols} argument (its type is whatever
+// the callback's second parameter is). The callback's return value is stored
+// in C[row * B_cols + col].
+template <typename LambdaCallback>
+void MultiplyFF(const float* A, const float* B, float* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) {
+  for (Index row = 0; row < A_rows; ++row) {
+    for (Index col = 0; col < B_cols; ++col) {
+      float dot = 0.0f;
+      Index inner = 0;
+      while (inner < width) {
+        dot += A[row * width + inner] * B[inner * B_cols + col];
+        ++inner;
+      }
+      const Index out_index = row * B_cols + col;
+      C[out_index] = callback(dot, {row, col, A_rows, B_cols});
+    }
+  }
+}
+
+// Multiply A(int) x B(int)
+// Reference integer matrix product; only participates in overload resolution
+// when both TypeA and TypeB are integral types.
+// A is indexed as A[row * width + k] and B as B[k * B_cols + col]. Each
+// element is widened to int32_t before multiplying and the dot product is
+// accumulated in an int32_t. The accumulated value and a brace-initialised
+// {row, col, A_rows, B_cols} argument are passed to `callback`, whose return
+// value is written to C[row * B_cols + col].
+template <typename TypeA, typename TypeB, typename LambdaCallback,
+          typename std::enable_if<std::is_integral<TypeA>::value>::type* = nullptr,
+          typename std::enable_if<std::is_integral<TypeB>::value>::type* = nullptr>
+void Multiply(const TypeA* A, const TypeB* B, float* C, Index A_rows, Index width, Index B_cols, LambdaCallback callback) {
+  for (Index row = 0; row < A_rows; ++row) {
+    for (Index col = 0; col < B_cols; ++col) {
+      int32_t dot = 0;
+      Index inner = 0;
+      while (inner < width) {
+        const int32_t lhs = static_cast<int32_t>(A[row * width + inner]);
+        const int32_t rhs = static_cast<int32_t>(B[inner * B_cols + col]);
+        dot += lhs * rhs;
+        ++inner;
+      }
+      const Index out_index = row * B_cols + col;
+      C[out_index] = callback(dot, {row, col, A_rows, B_cols});
+    }
+  }
+}
+
+// Matrix rearrangement
+// Reference tile rearrangement. `input` is read as a row-major rows x cols
+// matrix and walked in tiles of `simd` rows by `unroll` columns: column
+// blocks left to right, and within a column block tiles top to bottom.
+// Each tile is appended to `output` as unroll * simd consecutive elements,
+// laid out so that output[simd * i + j] holds the tile's row j, column i.
+//   simd, unroll - tile height and width; if either is not positive the
+//                  function returns without writing anything.
+// Partial tiles are not handled: presumably rows is a multiple of simd and
+// cols a multiple of unroll - confirm against callers.
+template <typename Type>
+void Rearragement(const Type* input, Type* output, int simd, int unroll, Index rows, Index cols) {
+  // A zero step would never advance the loops below, and a negative one
+  // would turn into a huge unsigned bound when mixed with Index.
+  if (simd <= 0 || unroll <= 0)
+    return;
+  // Convert once so all loop arithmetic is done in Index without mixed
+  // signed/unsigned comparisons.
+  const Index tile_rows = static_cast<Index>(simd);
+  const Index tile_cols = static_cast<Index>(unroll);
+  for (Index c = 0; c < cols; c += tile_cols) {
+    for (Index r = 0; r < rows; r += tile_rows) {
+      for (Index i = 0; i < tile_cols; ++i)
+        for (Index j = 0; j < tile_rows; ++j)
+          output[tile_rows * i + j] = input[cols * r + c + cols * j + i];
+      // Advance to the slot of the next tile.
+      output += tile_cols * tile_rows;
+    }
+  }
+}
-void Compare(const float *float_ref, const float *int_ref, const float *int_test, std::size_t size, std::string test_info,
- float int_tolerance, float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance);
+// Transpose
+// Reference transpose: copies input[cols * r + c] to output[rows * c + r] for
+// every r in [0, rows) and c in [0, cols), i.e. a row-major rows x cols
+// matrix is written out as a row-major cols x rows matrix.
+template <typename Type>
+void Transpose(const Type* input, Type* output, Index rows, Index cols) {
+  for (Index src_row = 0; src_row < rows; ++src_row) {
+    for (Index src_col = 0; src_col < cols; ++src_col) {
+      const Index from = cols * src_row + src_col;
+      const Index to = rows * src_col + src_row;
+      output[to] = input[from];
+    }
+  }
+}
-} //namespace intgemm
+} // namespace references
+} // namespace intgemm