Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'test/test.cc')
-rw-r--r--test/test.cc56
1 files changed, 2 insertions, 54 deletions
diff --git a/test/test.cc b/test/test.cc
index 2986d82..62137a1 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -7,60 +7,8 @@ int main(int argc, char ** argv) {
namespace intgemm {
-void SlowRefFloat(const float *A, const float *B, float *C, Index A_rows, Index width, Index B_cols, const float *bias) {
- for (Index r = 0; r < A_rows; ++r) {
- for (Index c = 0; c < B_cols; ++c) {
- float sum = 0.0f;
- for (Index k = 0; k < width; ++k) {
- sum += A[r * width + k] * B[k * B_cols + c];
- }
- if (bias) {
- C[r * B_cols + c] = sum + bias[c];
- } else {
- C[r * B_cols + c] = sum;
- }
- }
- }
-}
-
-// Compute A*B slowly from integers.
-template <class Integer> void SlowRefInt(const Integer *A, const Integer *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias) {
- for (Index r = 0; r < A_rows; ++r) {
- for (Index c = 0; c < B_cols; ++c) {
- int32_t sum = 0;
- for (Index k = 0; k < width; ++k) {
- sum += static_cast<int16_t>(A[r * width + k]) * static_cast<int16_t>(B[k * B_cols + c]);
- }
- if (bias) {
- C[r * B_cols + c] = sum * unquant_mult + bias[c];
- } else {
- C[r * B_cols + c] = sum * unquant_mult;
- }
- }
- }
-}
-void SlowRefInt(const uint8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias) {
- for (Index r = 0; r < A_rows; ++r) {
- for (Index c = 0; c < B_cols; ++c) {
- int32_t sum = 0;
- for (Index k = 0; k < width; ++k) {
- sum += static_cast<int16_t>(A[r * width + k]) * static_cast<int16_t>(B[k * B_cols + c]);
- }
- if (bias) {
- C[r * B_cols + c] = sum * unquant_mult + bias[c];
- } else {
- C[r * B_cols + c] = sum * unquant_mult;
- }
- }
- }
-}
-
-template void SlowRefInt<int8_t>(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias);
-template void SlowRefInt<int16_t>(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias);
-template void SlowRefInt<int32_t>(const int32_t *A, const int32_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias);
-
void Compare(const float *float_ref, const float *int_ref, const float *int_test, std::size_t size, std::string test_info,
- float int_tolerance, float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance) {
+ float int_tolerance, float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance) {
float int_sum = 0.0, float_sum = 0.0;
for (std::size_t i = 0; i < size; ++i) {
float int_diff = int_ref[i] - int_test[i];
@@ -74,4 +22,4 @@ void Compare(const float *float_ref, const float *int_ref, const float *int_test
CHECK_MESSAGE(fabs(sqrt(int_sum / size)) <= MSE_int_tolerance, test_info << "Int MSE = " << sqrt(int_sum / size));
}
-} //namespace intgemm
+} // namespace intgemm