diff options
Diffstat (limited to 'bench/FP16Benchmark.cc')
-rw-r--r-- | bench/FP16Benchmark.cc | 34 |
1 files changed, 19 insertions, 15 deletions
diff --git a/bench/FP16Benchmark.cc b/bench/FP16Benchmark.cc index c03f18a..fd9de5b 100644 --- a/bench/FP16Benchmark.cc +++ b/bench/FP16Benchmark.cc @@ -73,20 +73,24 @@ void performance_test() { int n = s[1]; int k = s[2]; - aligned_vector<float> A(m * k, 0.f); - aligned_vector<float> B(k * n, 0.f); - aligned_vector<float> Cg(m * n, 1.f); - aligned_vector<float> Cp(m * n, NAN); + aligned_vector<float> C_ref(m * n, 1.f); + aligned_vector<float> C_fb(m * n, NAN); // initialize with small numbers - randFill(A, 0, 4); + aligned_vector<int> Aint(m * k); + randFill(Aint, 0, 4); + aligned_vector<float> A(Aint.begin(), Aint.end()); - randFill(B, 0, 4); + aligned_vector<int> Bint(k * n); + randFill(Bint, 0, 4); + aligned_vector<float> B(Bint.begin(), Bint.end()); PackedGemmMatrixFP16 Bp(btran, k, n, alpha, B.data()); if (beta != 0.0f) { - randFill(Cg, 0, 4); - Cp = Cg; + aligned_vector<int> Cint(C_ref.size()); + randFill(Cint, 0, 4); + C_ref.assign(Cint.begin(), Cint.end()); + C_fb = C_ref; } double nflops = 2.0 * (double)m * (double)n * (double)k * (double)NITER; @@ -111,17 +115,17 @@ void performance_test() { B.data(), (btran == matrix_op_t::NoTranspose) ? n : k, beta, - Cg.data(), + C_ref.data(), n); #endif cblas_gemm_compute( - matrix_op_t::NoTranspose, m, A.data(), Bp, beta, Cp.data()); + matrix_op_t::NoTranspose, m, A.data(), Bp, beta, C_fb.data()); #ifdef USE_MKL // Compare results - for (auto i = 0; i < Cg.size(); i++) { - // printf("%f %f\n", Cg[i], Cp[i]); - assert(std::abs(Cg[i] - Cp[i]) < 1e-3); + for (auto i = 0; i < C_ref.size(); i++) { + // printf("%f %f\n", C_ref[i], C_fb[i]); + assert(std::abs(C_ref[i] - C_fb[i]) < 1e-3); } #endif } @@ -151,7 +155,7 @@ void performance_test() { B.data(), (btran == matrix_op_t::NoTranspose) ? n : k, beta, - Cg.data(), + C_ref.data(), n); t_end = chrono::system_clock::now(); if (it >= 0) { @@ -184,7 +188,7 @@ void performance_test() { t_begin = chrono::system_clock::now(); cblas_gemm_compute( - matrix_op_t::NoTranspose, m, A.data(), Bp, beta, Cp.data()); + matrix_op_t::NoTranspose, m, A.data(), Bp, beta, C_fb.data()); t_end = chrono::system_clock::now(); if (it >= 0) { |