diff options
author | dskhudia <dskhudia@fb.com> | 2018-11-03 21:05:52 +0300 |
---|---|---|
committer | dskhudia <dskhudia@fb.com> | 2018-11-03 21:05:52 +0300 |
commit | 505eb847185c9255526813dd39edadcd4e61d8e0 (patch) | |
tree | 480e3b4f3125d8b0f39047464ae7a7b233585f49 | |
parent | e85b5a12254fa47ca6b56236489253a68fd32104 (diff) |
Manually syncing with internal copy
-rw-r--r-- | CMakeLists.txt | 16 | ||||
-rw-r--r-- | bench/Im2ColFusedRequantizeAcc16Benchmark.cc | 230 | ||||
-rw-r--r-- | bench/Im2ColFusedRequantizeAcc32Benchmark.cc | 227 | ||||
-rw-r--r-- | src/PackAWithIm2Col.cc | 60 | ||||
-rw-r--r-- | src/RefImplementations.cc | 42 | ||||
-rw-r--r-- | test/I8SpmdmTest.cc | 2 | ||||
-rw-r--r-- | test/Im2ColFusedRequantizeTest.cc | 210 |
7 files changed, 657 insertions, 130 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index cfc47b5..8a477d6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.7 FATAL_ERROR) +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules") @@ -96,17 +96,17 @@ if(NOT TARGET asmjit) endif() if(NOT TARGET cpuinfo) - #Download cpuinfo from github if CPUINFO_SRC_DIR is not specified. - if(NOT DEFINED CPUINFO_SRC_DIR) + #Download cpuinfo from github if CPUINFO_SOURCE_DIR is not specified. + if(NOT DEFINED CPUINFO_SOURCE_DIR) message(STATUS "Downloading cpuinfo to ${FBGEMM_THIRDPARTY_DIR}/cpuinfo - (define CPUINFO_SRC_DIR to avoid it)") + (define CPUINFO_SOURCE_DIR to avoid it)") configure_file("${FBGEMM_SOURCE_DIR}/cmake/modules/DownloadCPUINFO.cmake" "${FBGEMM_BINARY_DIR}/cpuinfo-download/CMakeLists.txt") execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . WORKING_DIRECTORY "${FBGEMM_BINARY_DIR}/cpuinfo-download") execute_process(COMMAND "${CMAKE_COMMAND}" --build . WORKING_DIRECTORY "${FBGEMM_BINARY_DIR}/cpuinfo-download") - set(CPUINFO_SRC_DIR "${FBGEMM_THIRDPARTY_DIR}/cpuinfo" CACHE STRING + set(CPUINFO_SOURCE_DIR "${FBGEMM_THIRDPARTY_DIR}/cpuinfo" CACHE STRING "cpuinfo source directory") endif() @@ -115,7 +115,7 @@ if(NOT TARGET cpuinfo) set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE BOOL "Do not build cpuinfo mock tests") set(CPUINFO_BUILD_BENCHMARKS OFF CACHE BOOL "Do not build cpuinfo benchmarks") set(CPUINFO_LIBRARY_TYPE static) - add_subdirectory("${CPUINFO_SRC_DIR}" "${FBGEMM_BINARY_DIR}/cpuinfo") + add_subdirectory("${CPUINFO_SOURCE_DIR}" "${FBGEMM_BINARY_DIR}/cpuinfo") set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON) endif() @@ -123,13 +123,13 @@ target_include_directories(fbgemm_avx2 BEFORE PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}> PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}/include> PRIVATE "${ASMJIT_SRC_DIR}/src" - PRIVATE "${CPUINFO_SRC_DIR}/include") + PRIVATE "${CPUINFO_SOURCE_DIR}/include") target_include_directories(fbgemm_avx512 BEFORE PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}> PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}/include> PRIVATE "${ASMJIT_SRC_DIR}/src" - PRIVATE "${CPUINFO_SRC_DIR}/include") + PRIVATE "${CPUINFO_SOURCE_DIR}/include") if(FBGEMM_LIBRARY_TYPE STREQUAL "default") add_library(fbgemm $<TARGET_OBJECTS:fbgemm_avx2> diff --git a/bench/Im2ColFusedRequantizeAcc16Benchmark.cc b/bench/Im2ColFusedRequantizeAcc16Benchmark.cc index 62010ec..c24f6fa 100644 --- a/bench/Im2ColFusedRequantizeAcc16Benchmark.cc +++ b/bench/Im2ColFusedRequantizeAcc16Benchmark.cc @@ -7,6 +7,7 @@ #include <algorithm> #include <chrono> #include <cmath> +#include <iomanip> #include <iostream> #include <random> #include <vector> @@ -29,44 +30,44 @@ void performance_test() { conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1), conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0), conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1), - conv_param_t( 1, 272, 272, 47, 125, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 272, 272, 75, 100, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 272, 272, 109, 75, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 544, 544, 24, 63, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 544, 544, 33, 63, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 544, 544, 34, 50, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 544, 544, 36, 63, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 544, 544, 38, 38, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 544, 544, 38, 40, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 544, 544, 47, 38, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 51, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 100, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 248, 248, 93, 250, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 248, 248, 128, 250, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 248, 248, 133, 200, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 248, 248, 150, 150, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 248, 248, 150, 151, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 248, 248, 150, 158, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 248, 248, 188, 150, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 248, 248, 225, 150, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 272, 272, 47, 125, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 51, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 100, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 8, 8, 4, 4, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 272, 272, 64, 125, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 272, 272, 66, 125, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 272, 272, 67, 100, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 272, 272, 75, 75, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 272, 272, 75, 76, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 272, 272, 75, 100, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 272, 272, 94, 75, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 272, 272, 109, 75, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 544, 544, 24, 63, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 544, 544, 33, 63, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 544, 544, 34, 50, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 544, 544, 36, 63, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 544, 544, 38, 38, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 544, 544, 38, 40, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 544, 544, 47, 38, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(51, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(100, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 248, 248, 93, 250, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 248, 248, 128, 250, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 248, 248, 133, 200, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 248, 248, 150, 150, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 248, 248, 150, 151, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 248, 248, 150, 158, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 248, 248, 188, 150, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 248, 248, 225, 150, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 272, 272, 64, 125, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 272, 272, 66, 125, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 272, 272, 67, 100, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 272, 272, 75, 75, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 272, 272, 75, 76, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 272, 272, 94, 75, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(51, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(100, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 8, 8, 4, 4, 1, 3, 3, 1, 1, 1, 1), }; bool flush = true; @@ -79,6 +80,49 @@ void performance_test() { constexpr int NWARMUP = 4; constexpr int NITER = 10; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << "WARNING: the timer may be inaccurate when used by multiple threads." + << endl; + cout << "MB, " + << "IC, " + << "OC, " + << "IH, " + << "IW, " + << "KH, " + << "KW, " + << "stride_h, " + << "stride_w, " + << "pad_h, " + << "pad_w, " + << "Type, " + << "M, " + << "N, " + << "K, " + << "Im2Col (ms), " + << "Packing (ms), " + << "Kernel (ms), " + << "Postprocessing (ms), " + << "fbgemmPacked (ms), " + << "Total (ms), " + << "GOPS" << endl; +#else + cout << setw(8) << "MB, " + << "IC, " + << "OC, " + << "IH, " + << "IW, " + << "KH, " + << "KW, " + << "stride_h, " + << "stride_w, " + << "pad_h, " + << "pad_w, " + << "Type, " + << "M, " + << "N, " + << "K, " << setw(5) << "GOPS" << endl; +#endif + chrono::time_point<chrono::high_resolution_clock> begin, end; for (auto conv_p : shapes) { aligned_vector<float> Afp32( @@ -96,7 +140,7 @@ void performance_test() { conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0); aligned_vector<int32_t> Cint32_ref( - conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0.0f); + conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0); aligned_vector<int32_t> Cint32_fb( conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0); @@ -104,8 +148,6 @@ void performance_test() { aligned_vector<int32_t> Cint32_fb2( conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0); - cout << conv_p.toString() << endl; - // A matrix (input activations) randFill(Afp32, 0, 5); int32_t Aint8_zero_point = 4; @@ -137,7 +179,9 @@ void performance_test() { // "B unpacked"); // packedB.printPackedMatrix("B Packed"); - double ttot = 0; + double nops = 2.0 * static_cast<double>(NITER) * MDim * NDim * KDim; + double ttot = 0.0; + string runType; vector<int32_t> row_offset_buf; row_offset_buf.resize( @@ -153,8 +197,25 @@ void performance_test() { DoNothing<int32_t, int32_t> doNothing32BitObj; memCopy<> memcopyObj(doNothing32BitObj); + runType = "FusedIm2Col"; ttot = 0; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + double im2col_time = 0.0; + double total_im2col_time = 0.0; + double total_packing_time = 0.0; + double total_computing_time = 0.0; + double total_kernel_time = 0.0; + double total_postprocessing_time = 0.0; + double total_run_time = 0.0; +#endif for (auto i = 0; i < NWARMUP + NITER; ++i) { +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + packing_time = 0.0; + computing_time = 0.0; + kernel_time = 0.0; + postprocessing_time = 0.0; + run_time = 0.0; +#endif llc_flush(llc); begin = chrono::high_resolution_clock::now(); fbgemmPacked( @@ -171,20 +232,69 @@ void performance_test() { if (i >= NWARMUP) { auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin); ttot += dur.count(); +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + total_packing_time += packing_time; + total_computing_time += computing_time; + total_kernel_time += kernel_time; + total_postprocessing_time += postprocessing_time; + total_run_time += run_time; +#endif } } - cout << fixed << "fused im2col GOPs: " - << static_cast<double>(NITER) * 2 * MDim * NDim * KDim / ttot << endl; + + cout << setw(4) << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC + << ", " << conv_p.IH << ", " << conv_p.IW << ", " << conv_p.G << ", " + << conv_p.KH << ", " << conv_p.KW << ", " << conv_p.stride_h << ", " + << conv_p.stride_w << ", " << conv_p.pad_h << ", " << conv_p.pad_w + << ", "; + + cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5) + << setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6) + << KDim << ", "; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << fixed << setprecision(6) << setw(8) << 0 << ", " + << total_packing_time / (double)NITER / 1e6 << ", " + << total_kernel_time / (double)NITER / 1e6 << ", " + << total_postprocessing_time / (double)NITER / 1e6 << ", " + << total_run_time / (double)NITER / 1e6 << ", " + << ttot / (double)NITER / 1e6 << ", "; +#endif + cout << setprecision(2) << nops / ttot << endl; compare_buffers(Cint32_ref.data(), Cint32_fb.data(), MDim, NDim, NDim, 5); + runType = "UnfusedIm2Col"; + row_offset_buf.resize( + PackAWithRowOffset<uint8_t, int16_t>::rowOffsetBufferSize()); ttot = 0; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + total_im2col_time = 0.0; + total_packing_time = 0.0; + total_computing_time = 0.0; + total_kernel_time = 0.0; + total_postprocessing_time = 0.0; + total_run_time = 0.0; +#endif for (auto i = 0; i < NWARMUP + NITER; ++i) { +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + im2col_time = 0.0; + packing_time = 0.0; + computing_time = 0.0; + kernel_time = 0.0; + postprocessing_time = 0.0; + run_time = 0.0; +#endif llc_flush(llc); begin = chrono::high_resolution_clock::now(); im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_out.data()); +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + end = chrono::high_resolution_clock::now(); + im2col_time = + chrono::duration_cast<chrono::nanoseconds>(end - begin).count(); +#endif + // printMatrix(matrix_op_t::NoTranspose, Aint8_out.data(), MDim, KDim, // KDim, "A_out after im2col unpacked"); @@ -213,8 +323,17 @@ void performance_test() { if (i >= NWARMUP) { auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin); ttot += dur.count(); +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + total_im2col_time += im2col_time; + total_packing_time += packing_time; + total_computing_time += computing_time; + total_kernel_time += kernel_time; + total_postprocessing_time += postprocessing_time; + total_run_time += run_time; +#endif } } + ((volatile char*)(llc.data())); // packedB.printPackedMatrix("bench B Packed"); @@ -225,9 +344,26 @@ void performance_test() { // printMatrix(matrix_op_t::NoTranspose, // Cint32_ref.data(), MDim, NDim, NDim, "C ref fp32"); - cout << fixed << "unfused im2col GOPs: " - << static_cast<double>(NITER) * 2 * MDim * NDim * KDim / ttot << endl; - // cout << "total time: " << ttot << " ns" << endl; + cout << setw(4) << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC + << ", " << conv_p.IH << ", " << conv_p.IW << ", " << conv_p.G << ", " + << conv_p.KH << ", " << conv_p.KW << ", " << conv_p.stride_h << ", " + << conv_p.stride_w << ", " << conv_p.pad_h << ", " << conv_p.pad_w + << ", "; + + cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5) + << setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6) + << KDim << ", "; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << fixed << setprecision(6) << setw(8) + << total_im2col_time / (double)NITER / 1e6 << ", " + << total_packing_time / (double)NITER / 1e6 << ", " + << total_kernel_time / (double)NITER / 1e6 << ", " + << total_postprocessing_time / (double)NITER / 1e6 << ", " + << total_run_time / (double)NITER / 1e6 << ", " + << ttot / (double)NITER / 1e6 << ", "; +#endif + cout << setprecision(2) << nops / ttot << endl; + compare_buffers(Cint32_ref.data(), Cint32_fb2.data(), MDim, NDim, NDim, 5); } // shapes } diff --git a/bench/Im2ColFusedRequantizeAcc32Benchmark.cc b/bench/Im2ColFusedRequantizeAcc32Benchmark.cc index 9adea49..b608915 100644 --- a/bench/Im2ColFusedRequantizeAcc32Benchmark.cc +++ b/bench/Im2ColFusedRequantizeAcc32Benchmark.cc @@ -7,6 +7,7 @@ #include <algorithm> #include <chrono> #include <cmath> +#include <iomanip> #include <iostream> #include <random> #include <vector> @@ -29,44 +30,44 @@ void performance_test() { conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1), conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0), conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1), - conv_param_t( 1, 272, 272, 47, 125, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 272, 272, 75, 100, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 272, 272, 109, 75, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 544, 544, 24, 63, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 544, 544, 33, 63, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 544, 544, 34, 50, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 544, 544, 36, 63, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 544, 544, 38, 38, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 544, 544, 38, 40, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 544, 544, 47, 38, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 51, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 100, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ), - conv_param_t( 1, 248, 248, 93, 250, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 248, 248, 128, 250, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 248, 248, 133, 200, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 248, 248, 150, 150, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 248, 248, 150, 151, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 248, 248, 150, 158, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 248, 248, 188, 150, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 248, 248, 225, 150, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 272, 272, 47, 125, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 51, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 100, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ), - conv_param_t( 1, 8, 8, 4, 4, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 272, 272, 64, 125, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 272, 272, 66, 125, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 272, 272, 67, 100, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 272, 272, 75, 75, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 272, 272, 75, 76, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 272, 272, 75, 100, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 272, 272, 94, 75, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 272, 272, 109, 75, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 544, 544, 24, 63, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 544, 544, 33, 63, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 544, 544, 34, 50, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 544, 544, 36, 63, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 544, 544, 38, 38, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 544, 544, 38, 40, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 544, 544, 47, 38, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(51, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(100, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 248, 248, 93, 250, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 248, 248, 128, 250, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 248, 248, 133, 200, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 248, 248, 150, 150, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 248, 248, 150, 151, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 248, 248, 150, 158, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 248, 248, 188, 150, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 248, 248, 225, 150, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 272, 272, 64, 125, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 272, 272, 66, 125, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 272, 272, 67, 100, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 272, 272, 75, 75, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 272, 272, 75, 76, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 272, 272, 94, 75, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(51, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(100, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 8, 8, 4, 4, 1, 3, 3, 1, 1, 1, 1), }; bool flush = true; @@ -79,6 +80,49 @@ void performance_test() { constexpr int NWARMUP = 4; constexpr int NITER = 10; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << "WARNING: the timer may be inaccurate when used by multiple threads." + << endl; + cout << "MB, " + << "IC, " + << "OC, " + << "IH, " + << "IW, " + << "KH, " + << "KW, " + << "stride_h, " + << "stride_w, " + << "pad_h, " + << "pad_w, " + << "Type, " + << "M, " + << "N, " + << "K, " + << "Im2Col (ms), " + << "Packing (ms), " + << "Kernel (ms), " + << "Postprocessing (ms), " + << "fbgemmPacked (ms), " + << "Total (ms), " + << "GOPS" << endl; +#else + cout << setw(8) << "MB, " + << "IC, " + << "OC, " + << "IH, " + << "IW, " + << "KH, " + << "KW, " + << "stride_h, " + << "stride_w, " + << "pad_h, " + << "pad_w, " + << "Type, " + << "M, " + << "N, " + << "K, " << setw(5) << "GOPS" << endl; +#endif + chrono::time_point<chrono::high_resolution_clock> begin, end; for (auto conv_p : shapes) { aligned_vector<float> Afp32( @@ -96,7 +140,7 @@ void performance_test() { conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0); aligned_vector<int32_t> Cint32_ref( - conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0.0f); + conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0); aligned_vector<int32_t> Cint32_fb( conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0); @@ -104,7 +148,7 @@ void performance_test() { aligned_vector<int32_t> Cint32_fb2( conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0); - cout << conv_p.toString() << endl; + // cout << conv_p.toString() << endl; // A matrix (input activations) randFill(Afp32, 0, 5); @@ -137,7 +181,9 @@ void performance_test() { // "B unpacked"); // packedB.printPackedMatrix("B Packed"); - double ttot = 0; + double nops = 2.0 * static_cast<double>(NITER) * MDim * NDim * KDim; + double ttot = 0.0; + string runType; vector<int32_t> row_offset_buf; row_offset_buf.resize( @@ -153,8 +199,25 @@ void performance_test() { DoNothing<int32_t, int32_t> doNothing32BitObj; memCopy<> memcopyObj(doNothing32BitObj); + runType = "FusedIm2Col"; ttot = 0; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + double im2col_time = 0.0; + double total_im2col_time = 0.0; + double total_packing_time = 0.0; + double total_computing_time = 0.0; + double total_kernel_time = 0.0; + double total_postprocessing_time = 0.0; + double total_run_time = 0.0; +#endif for (auto i = 0; i < NWARMUP + NITER; ++i) { +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + packing_time = 0.0; + computing_time = 0.0; + kernel_time = 0.0; + postprocessing_time = 0.0; + run_time = 0.0; +#endif llc_flush(llc); begin = chrono::high_resolution_clock::now(); fbgemmPacked( @@ -171,20 +234,67 @@ void performance_test() { if (i >= NWARMUP) { auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin); ttot += dur.count(); +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + total_packing_time += packing_time; + total_computing_time += computing_time; + total_kernel_time += kernel_time; + total_postprocessing_time += postprocessing_time; + total_run_time += run_time; +#endif } } - cout << fixed << "fused im2col GOPs: " - << static_cast<double>(NITER) * 2 * MDim * NDim * KDim / ttot << endl; + + cout << setw(4) << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC + << ", " << conv_p.IH << ", " << conv_p.IW << ", " << conv_p.G << ", " + << conv_p.KH << ", " << conv_p.KW << ", " << conv_p.stride_h << ", " + << conv_p.stride_w << ", " << conv_p.pad_h << ", " << conv_p.pad_w + << ", "; + + cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5) + << setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6) + << KDim << ", "; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << fixed << setprecision(6) << setw(8) << 0 << ", " + << total_packing_time / (double)NITER / 1e6 << ", " + << total_kernel_time / (double)NITER / 1e6 << ", " + << total_postprocessing_time / (double)NITER / 1e6 << ", " + << total_run_time / (double)NITER / 1e6 << ", " + << ttot / (double)NITER / 1e6 << ", "; +#endif + cout << setprecision(2) << nops / ttot << endl; compare_buffers(Cint32_ref.data(), Cint32_fb.data(), MDim, NDim, NDim, 5); + runType = "UnfusedIm2Col"; ttot = 0; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + total_im2col_time = 0.0; + total_packing_time = 0.0; + total_computing_time = 0.0; + total_kernel_time = 0.0; + total_postprocessing_time = 0.0; + total_run_time = 0.0; +#endif for (auto i = 0; i < NWARMUP + NITER; ++i) { +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + im2col_time = 0.0; + packing_time = 0.0; + computing_time = 0.0; + kernel_time = 0.0; + postprocessing_time = 0.0; + run_time = 0.0; +#endif llc_flush(llc); begin = chrono::high_resolution_clock::now(); im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_out.data()); +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + end = chrono::high_resolution_clock::now(); + im2col_time = + chrono::duration_cast<chrono::nanoseconds>(end - begin).count(); +#endif + // printMatrix(matrix_op_t::NoTranspose, Aint8_out.data(), MDim, KDim, // KDim, "A_out after im2col unpacked"); @@ -213,6 +323,14 @@ void performance_test() { if (i >= NWARMUP) { auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin); ttot += dur.count(); +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + total_im2col_time += im2col_time; + total_packing_time += packing_time; + total_computing_time += computing_time; + total_kernel_time += kernel_time; + total_postprocessing_time += postprocessing_time; + total_run_time += run_time; +#endif } } @@ -226,9 +344,26 @@ void performance_test() { // printMatrix(matrix_op_t::NoTranspose, // Cint32_ref.data(), MDim, NDim, NDim, "C ref fp32"); - cout << fixed << "unfused im2col GOPs: " - << static_cast<double>(NITER) * 2 * MDim * NDim * KDim / ttot << endl; - // cout << "total time: " << ttot << " ns" << endl; + cout << setw(4) << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC + << ", " << conv_p.IH << ", " << conv_p.IW << ", " << conv_p.G << ", " + << conv_p.KH << ", " << conv_p.KW << ", " << conv_p.stride_h << ", " + << conv_p.stride_w << ", " << conv_p.pad_h << ", " << conv_p.pad_w + << ", "; + + cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5) + << setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6) + << KDim << ", "; +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN + cout << fixed << setprecision(6) << setw(8) + << total_im2col_time / (double)NITER / 1e6 << ", " + << total_packing_time / (double)NITER / 1e6 << ", " + << total_kernel_time / (double)NITER / 1e6 << ", " + << total_postprocessing_time / (double)NITER / 1e6 << ", " + << total_run_time / (double)NITER / 1e6 << ", " + << ttot / (double)NITER / 1e6 << ", "; +#endif + cout << setprecision(2) << nops / ttot << endl; + compare_buffers(Cint32_ref.data(), Cint32_fb2.data(), MDim, NDim, NDim, 5); } // shapes } diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc index 7012289..a007685 100644 --- a/src/PackAWithIm2Col.cc +++ b/src/PackAWithIm2Col.cc @@ -10,6 +10,8 @@ #include <iostream> #include "fbgemm/Fbgemm.h" +#include <algorithm> + namespace fbgemm2 { template <typename T, typename accT> @@ -50,6 +52,7 @@ PackAWithIm2Col<T, accT>::PackAWithIm2Col( aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T))); } if (row_offset) { + rowOffsetAllocatedHere = false; row_offset_ = row_offset; } else { rowOffsetAllocatedHere = true; @@ -65,39 +68,66 @@ void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) { block.col_start, (block.col_size + row_interleave_B_ - 1) / row_interleave_B_ * row_interleave_B_}; - BaseType::packedBlock(block_p); T* out = BaseType::getBuf(); + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { int n = i / (conv_p_.OH * conv_p_.OW); int hw = i % (conv_p_.OH * conv_p_.OW); int w = hw % conv_p_.OW; int h = hw / conv_p_.OW; - for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { - int c = j % conv_p_.IC; + for (int j = block.col_start; + j < block.col_start + block.col_size + conv_p_.IC - 1; + j += conv_p_.IC) { + int j_blk_id = j / conv_p_.IC; + // max( j_blk_id * IC, START) -> min( END, (j_blk_id + 1) * IC ) + int j_blk_start = std::max(j_blk_id * conv_p_.IC, block.col_start); + int j_blk_end = std::min( + (j_blk_id + 1) * conv_p_.IC, block.col_start + block.col_size); + if (j_blk_start >= j_blk_end) { + break; + } + int rs = j / conv_p_.IC; int s = rs % conv_p_.KW; int r = rs / conv_p_.KW; int w_in = -conv_p_.pad_w + w * conv_p_.stride_w + s; int h_in = -conv_p_.pad_h + h * conv_p_.stride_h + r; - // Please note that padding for convolution should be filled with zero_pt + if (h_in < 0 || h_in >= conv_p_.IH || w_in < 0 || w_in >= conv_p_.IW) { - out[(i - block.row_start) * BaseType::blockColSize() + - (j - block.col_start)] = BaseType::zeroPoint(); + // Please note that padding for convolution should be filled with + // zero_pt + std::memset( + &out + [(i - block.row_start) * BaseType::blockColSize() + + (j_blk_start - block.col_start)], + BaseType::zeroPoint(), + sizeof(T) * (j_blk_end - j_blk_start)); } else { - out[(i - block.row_start) * BaseType::blockColSize() + - (j - block.col_start)] = sdata_ - [((n * conv_p_.IH + h_in) * conv_p_.IW + w_in) * conv_p_.IC + c]; + std::memcpy( + &out + [(i - block.row_start) * BaseType::blockColSize() + + j_blk_start - block.col_start], + &sdata_ + [((n * conv_p_.IH + h_in) * conv_p_.IW + w_in) * conv_p_.IC + + (j_blk_start % conv_p_.IC)], + sizeof(T) * (j_blk_end - j_blk_start)); } } // zero fill // Please see the comment in PackAMatrix.cc for zero vs zero_pt fill. - for (int j = block.col_start + block.col_size; - j < block_p.col_start + block_p.col_size; - ++j) { - out[(i - block.row_start) * BaseType::blockColSize() + - (j - block.col_start)] = 0; + if ((block_p.col_start + block_p.col_size) - + (block.col_start + block.col_size) > + 0) { + std::memset( + &out + [(i - block.row_start) * BaseType::blockColSize() + + (block.col_size)], + 0, + sizeof(T) * + ((block_p.col_start + block_p.col_size) - + (block.col_start + block.col_size))); } } } @@ -111,7 +141,7 @@ void PackAWithIm2Col<T, accT>::printPackedMatrix(std::string name) { T* out = BaseType::getBuf(); for (auto r = 0; r < BaseType::numPackedRows(); ++r) { for (auto c = 0; c < BaseType::numPackedCols(); ++c) { - T val = out[ r * BaseType::blockColSize() + c ]; + T val = out[r * BaseType::blockColSize() + c]; if (std::is_integral<T>::value) { // cast to int64 because cout doesn't print int8_t type directly std::cout << std::setw(5) << static_cast<int64_t>(val) << " "; diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc index 9aedc88..10e581f 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -8,6 +8,7 @@ #include <cassert> #include <cmath> +#include <cstring> using namespace std; @@ -226,6 +227,12 @@ int32_t clip_16bit(int32_t x) { } } +/* Imitate the Im2Col<float, CPUContext, StorageOrder::NHWC> function + * from caffe2/utils/math_cpu.cc + * NHWC StorageOrder/Layout + * A: NHWC: NH_0W_0 x C_0 + * Ao: NHWC: NH_1W_1 x RSC_0 + */ void im2col_ref( const conv_param_t& conv_p, const std::uint8_t* A, @@ -238,20 +245,27 @@ void im2col_ref( int h_in = -conv_p.pad_h + h * conv_p.stride_h + r; for (int s = 0; s < conv_p.KW; ++s) { int w_in = -conv_p.pad_w + w * conv_p.stride_w + s; - for (int c = 0; c < conv_p.IC; ++c) { - // Ai: NHWC: NH_0W_0 x C_0 - std::uint8_t val = - h_in < 0 || h_in >= conv_p.IH || w_in < 0 || w_in >= conv_p.IW - ? A_zero_point - : A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) * conv_p.IC + - c]; - // Ao: NHWC: NH_1W_1 x RSC_0 - Ao[((((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.KH + r) * - conv_p.KW + - s) * - conv_p.IC + - c] = val; - } // for each c + if (h_in < 0 || h_in >= conv_p.IH || w_in < 0 || + w_in >= conv_p.IW) { + std::memset( + &Ao[((((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.KH + r) * + conv_p.KW + + s) * + conv_p.IC + + 0], + A_zero_point, + sizeof(uint8_t) * conv_p.IC); + } else { + std::memcpy( + &Ao[((((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.KH + r) * + conv_p.KW + + s) * + conv_p.IC + + 0], + &A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) * conv_p.IC + + 0], + sizeof(uint8_t) * conv_p.IC); + } } // for each s } // for each r } // for each w diff --git a/test/I8SpmdmTest.cc b/test/I8SpmdmTest.cc index c74c98a..cd8a94c 100644 --- a/test/I8SpmdmTest.cc +++ b/test/I8SpmdmTest.cc @@ -124,7 +124,9 @@ TEST_P(fbgemmSPMDMTest, TestsSpMDM) { } } +#ifdef _OPENMP #pragma omp parallel +#endif { #ifdef _OPENMP int num_threads = omp_get_num_threads(); diff --git a/test/Im2ColFusedRequantizeTest.cc b/test/Im2ColFusedRequantizeTest.cc new file mode 100644 index 0000000..c09c770 --- /dev/null +++ b/test/Im2ColFusedRequantizeTest.cc @@ -0,0 +1,210 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <cmath> +#include <cstdio> + +#include <gtest/gtest.h> + +#include "bench/AlignedVec.h" +#include "bench/BenchUtils.h" +#include "fbgemm/Fbgemm.h" +#include "src/RefImplementations.h" +#include "TestUtils.h" + +using namespace std; + +namespace fbgemm2 { + +// From Xray OCR +static vector<conv_param_t> shapes = { + // MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w + conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0), + conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0), + conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1), + conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 1, 1, 1, 1), + // conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 1, 1, 1, 1 ), + // conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 1, 1, 1, 1 ), + // conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 1, 1, 1, 1 ), + // conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 1, 1, 1, 1 ), + // conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 1, 1, 1, 1 ), + // conv_param_t( 1, 272, 272, 75, 100, 1, 3, 3, 1, 1, 1, 1 ), + // conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 1, 1, 1, 1 ), + // conv_param_t(1, 272, 272, 109, 75, 1, 3, 3, 1, 1, 1, 1), + // conv_param_t(1, 544, 544, 24, 63, 1, 3, 3, 1, 1, 1, 1), + // conv_param_t( 1, 544, 544, 33, 63, 1, 3, 3, 1, 1, 1, 1 ), + // conv_param_t( 1, 544, 544, 34, 50, 1, 3, 3, 1, 1, 1, 1 ), + // conv_param_t( 1, 544, 544, 36, 63, 1, 3, 3, 1, 1, 1, 1 ), + // conv_param_t( 1, 544, 544, 38, 38, 1, 3, 3, 1, 1, 1, 1 ), + // conv_param_t( 1, 544, 544, 38, 40, 1, 3, 3, 1, 1, 1, 1 ), + // conv_param_t( 1, 544, 544, 47, 38, 1, 3, 3, 1, 1, 1, 1 ), + // conv_param_t( 1, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ), + // conv_param_t(51, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1), + // conv_param_t( 100, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ), + conv_param_t(1, 248, 248, 93, 250, 1, 3, 3, 2, 2, 1, 1), + // conv_param_t( 1, 248, 248, 128, 250, 1, 3, 3, 2, 2, 1, 1 ), + // conv_param_t( 1, 248, 248, 133, 200, 1, 3, 3, 2, 2, 1, 1 ), + // conv_param_t( 1, 248, 248, 150, 150, 1, 3, 3, 2, 2, 1, 1 ), + // conv_param_t( 1, 248, 248, 150, 151, 1, 3, 3, 2, 2, 1, 1 ), + // conv_param_t( 1, 248, 248, 150, 158, 1, 3, 3, 2, 2, 1, 1 ), + // conv_param_t( 1, 248, 248, 188, 150, 1, 3, 3, 2, 2, 1, 1 ), + // conv_param_t(1, 248, 248, 225, 150, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 2, 2, 1, 1), + // conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 2, 2, 1, 1 ), + // conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 2, 2, 1, 1 ), + // conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 2, 2, 1, 1 ), + // conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 2, 2, 1, 1 ), + // conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 2, 2, 1, 1 ), + // conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 2, 2, 1, 1 ), + // conv_param_t( 1, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ), + // conv_param_t(51, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1), + conv_param_t(3, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1), + // conv_param_t( 100, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ), + conv_param_t(1, 8, 8, 4, 4, 1, 3, 3, 1, 1, 1, 1), +}; + +TEST(FBGemmIm2colTest, Acc32Test) { + for (auto conv_p : shapes) { + aligned_vector<uint8_t> Aint8( + conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0); + aligned_vector<int8_t> Bint8( + conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0); + aligned_vector<int32_t> Cint32_ref( + conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0.0f); + aligned_vector<int32_t> Cint32_fb( + conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0); + + randFill(Aint8, 0, 80); + int32_t Aint8_zero_point = 43; + randFill(Bint8, -16, 16); + + conv_ref( + conv_p, + Aint8.data(), + Aint8_zero_point, + Bint8.data(), + Cint32_ref.data()); + + int NDim = conv_p.OC; + int KDim = conv_p.KH * conv_p.KW * conv_p.IC; + + vector<int32_t> row_offset_buf; + row_offset_buf.resize( + PackAWithIm2Col<uint8_t, int32_t>::rowOffsetBufferSize()); + + PackAWithIm2Col<uint8_t, int32_t> packA( + conv_p, Aint8.data(), nullptr, Aint8_zero_point, row_offset_buf.data()); + + PackBMatrix<int8_t, int32_t> packedB( + matrix_op_t::NoTranspose, KDim, NDim, Bint8.data(), NDim); + + // no-op output process objects + DoNothing<int32_t, int32_t> doNothing32BitObj; + memCopy<> memcopyObj(doNothing32BitObj); + + fbgemmPacked( + packA, + packedB, + Cint32_fb.data(), + Cint32_fb.data(), + NDim, + memcopyObj, + 0, + 1); + + // correctness check + for (int n = 0; n < conv_p.MB; ++n) { + for (int h = 0; h < conv_p.OH; ++h) { + for (int w = 0; w < conv_p.OW; ++w) { + for (int k = 0; k < conv_p.OC; ++k) { + int32_t expected = Cint32_ref + [((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k]; + int32_t actual = Cint32_fb + [((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k]; + EXPECT_EQ(expected, actual) + << "Im2Col fused results differ at (" << n << ", " << h << ", " + << w << ", " << k << ")."; + } + } + } + } + + } // for each shape +} // Acc32Test + + +TEST(FBGemmIm2colTest, Acc16Test) { + for (auto conv_p : shapes) { + aligned_vector<uint8_t> Aint8( + conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0); + aligned_vector<int8_t> Bint8( + conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0); + aligned_vector<int32_t> Cint32_ref( + conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0.0f); + aligned_vector<int32_t> Cint32_fb( + conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0); + + randFill(Aint8, 0, 5); + int32_t Aint8_zero_point = 4; + randFill(Bint8, -4, 4); + + conv_ref( + conv_p, + Aint8.data(), + Aint8_zero_point, + Bint8.data(), + Cint32_ref.data()); + + int NDim = conv_p.OC; + int KDim = conv_p.KH * conv_p.KW * conv_p.IC; + + vector<int32_t> row_offset_buf; + row_offset_buf.resize( + PackAWithIm2Col<uint8_t, int16_t>::rowOffsetBufferSize()); + + PackAWithIm2Col<uint8_t, int16_t> packA( + conv_p, Aint8.data(), nullptr, Aint8_zero_point, row_offset_buf.data()); + + PackBMatrix<int8_t, int16_t> packedB( + matrix_op_t::NoTranspose, KDim, NDim, Bint8.data(), NDim); + + // no-op output process objects + DoNothing<int32_t, int32_t> doNothing32BitObj; + memCopy<> memcopyObj(doNothing32BitObj); + + fbgemmPacked( + packA, + packedB, + Cint32_fb.data(), + Cint32_fb.data(), + NDim, + memcopyObj, + 0, + 1); + + // correctness check + for (int n = 0; n < conv_p.MB; ++n) { + for (int h = 0; h < conv_p.OH; ++h) { + for (int w = 0; w < conv_p.OW; ++w) { + for (int k = 0; k < conv_p.OC; ++k) { + int32_t expected = Cint32_ref + [((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k]; + int32_t actual = Cint32_fb + [((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k]; + EXPECT_EQ(expected, actual) + << "Im2Col fused results differ at (" << n << ", " << h << ", " + << w << ", " << k << ")."; + } + } + } + } + + } // for each shape +} // Acc16Test + + +} // namespace fbgemm2 |