diff options
Diffstat (limited to 'test/test.cc')
-rw-r--r-- | test/test.cc | 56 |
1 files changed, 2 insertions, 54 deletions
diff --git a/test/test.cc b/test/test.cc index 2986d82..62137a1 100644 --- a/test/test.cc +++ b/test/test.cc @@ -7,60 +7,8 @@ int main(int argc, char ** argv) { namespace intgemm { -void SlowRefFloat(const float *A, const float *B, float *C, Index A_rows, Index width, Index B_cols, const float *bias) { - for (Index r = 0; r < A_rows; ++r) { - for (Index c = 0; c < B_cols; ++c) { - float sum = 0.0f; - for (Index k = 0; k < width; ++k) { - sum += A[r * width + k] * B[k * B_cols + c]; - } - if (bias) { - C[r * B_cols + c] = sum + bias[c]; - } else { - C[r * B_cols + c] = sum; - } - } - } -} - -// Compute A*B slowly from integers. -template <class Integer> void SlowRefInt(const Integer *A, const Integer *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias) { - for (Index r = 0; r < A_rows; ++r) { - for (Index c = 0; c < B_cols; ++c) { - int32_t sum = 0; - for (Index k = 0; k < width; ++k) { - sum += static_cast<int16_t>(A[r * width + k]) * static_cast<int16_t>(B[k * B_cols + c]); - } - if (bias) { - C[r * B_cols + c] = sum * unquant_mult + bias[c]; - } else { - C[r * B_cols + c] = sum * unquant_mult; - } - } - } -} -void SlowRefInt(const uint8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias) { - for (Index r = 0; r < A_rows; ++r) { - for (Index c = 0; c < B_cols; ++c) { - int32_t sum = 0; - for (Index k = 0; k < width; ++k) { - sum += static_cast<int16_t>(A[r * width + k]) * static_cast<int16_t>(B[k * B_cols + c]); - } - if (bias) { - C[r * B_cols + c] = sum * unquant_mult + bias[c]; - } else { - C[r * B_cols + c] = sum * unquant_mult; - } - } - } -} - -template void SlowRefInt<int8_t>(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias); -template void SlowRefInt<int16_t>(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias); -template void SlowRefInt<int32_t>(const int32_t *A, const int32_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias); - void Compare(const float *float_ref, const float *int_ref, const float *int_test, std::size_t size, std::string test_info, - float int_tolerance, float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance) { + float int_tolerance, float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance) { float int_sum = 0.0, float_sum = 0.0; for (std::size_t i = 0; i < size; ++i) { float int_diff = int_ref[i] - int_test[i]; @@ -74,4 +22,4 @@ void Compare(const float *float_ref, const float *int_ref, const float *int_test CHECK_MESSAGE(fabs(sqrt(int_sum / size)) <= MSE_int_tolerance, test_info << "Int MSE = " << sqrt(int_sum / size)); } -} //namespace intgemm +} // namespace intgemm |