Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt16
-rw-r--r--bench/Im2ColFusedRequantizeAcc16Benchmark.cc230
-rw-r--r--bench/Im2ColFusedRequantizeAcc32Benchmark.cc227
-rw-r--r--src/PackAWithIm2Col.cc60
-rw-r--r--src/RefImplementations.cc42
-rw-r--r--test/I8SpmdmTest.cc2
-rw-r--r--test/Im2ColFusedRequantizeTest.cc210
7 files changed, 657 insertions, 130 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index cfc47b5..8a477d6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,4 +1,4 @@
-cmake_minimum_required(VERSION 3.7 FATAL_ERROR)
+cmake_minimum_required(VERSION 3.5 FATAL_ERROR)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
@@ -96,17 +96,17 @@ if(NOT TARGET asmjit)
endif()
if(NOT TARGET cpuinfo)
- #Download cpuinfo from github if CPUINFO_SRC_DIR is not specified.
- if(NOT DEFINED CPUINFO_SRC_DIR)
+ #Download cpuinfo from github if CPUINFO_SOURCE_DIR is not specified.
+ if(NOT DEFINED CPUINFO_SOURCE_DIR)
message(STATUS "Downloading cpuinfo to ${FBGEMM_THIRDPARTY_DIR}/cpuinfo
- (define CPUINFO_SRC_DIR to avoid it)")
+ (define CPUINFO_SOURCE_DIR to avoid it)")
configure_file("${FBGEMM_SOURCE_DIR}/cmake/modules/DownloadCPUINFO.cmake"
"${FBGEMM_BINARY_DIR}/cpuinfo-download/CMakeLists.txt")
execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" .
WORKING_DIRECTORY "${FBGEMM_BINARY_DIR}/cpuinfo-download")
execute_process(COMMAND "${CMAKE_COMMAND}" --build .
WORKING_DIRECTORY "${FBGEMM_BINARY_DIR}/cpuinfo-download")
- set(CPUINFO_SRC_DIR "${FBGEMM_THIRDPARTY_DIR}/cpuinfo" CACHE STRING
+ set(CPUINFO_SOURCE_DIR "${FBGEMM_THIRDPARTY_DIR}/cpuinfo" CACHE STRING
"cpuinfo source directory")
endif()
@@ -115,7 +115,7 @@ if(NOT TARGET cpuinfo)
set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE BOOL "Do not build cpuinfo mock tests")
set(CPUINFO_BUILD_BENCHMARKS OFF CACHE BOOL "Do not build cpuinfo benchmarks")
set(CPUINFO_LIBRARY_TYPE static)
- add_subdirectory("${CPUINFO_SRC_DIR}" "${FBGEMM_BINARY_DIR}/cpuinfo")
+ add_subdirectory("${CPUINFO_SOURCE_DIR}" "${FBGEMM_BINARY_DIR}/cpuinfo")
set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON)
endif()
@@ -123,13 +123,13 @@ target_include_directories(fbgemm_avx2 BEFORE
PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}>
PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}/include>
PRIVATE "${ASMJIT_SRC_DIR}/src"
- PRIVATE "${CPUINFO_SRC_DIR}/include")
+ PRIVATE "${CPUINFO_SOURCE_DIR}/include")
target_include_directories(fbgemm_avx512 BEFORE
PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}>
PUBLIC $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}/include>
PRIVATE "${ASMJIT_SRC_DIR}/src"
- PRIVATE "${CPUINFO_SRC_DIR}/include")
+ PRIVATE "${CPUINFO_SOURCE_DIR}/include")
if(FBGEMM_LIBRARY_TYPE STREQUAL "default")
add_library(fbgemm $<TARGET_OBJECTS:fbgemm_avx2>
diff --git a/bench/Im2ColFusedRequantizeAcc16Benchmark.cc b/bench/Im2ColFusedRequantizeAcc16Benchmark.cc
index 62010ec..c24f6fa 100644
--- a/bench/Im2ColFusedRequantizeAcc16Benchmark.cc
+++ b/bench/Im2ColFusedRequantizeAcc16Benchmark.cc
@@ -7,6 +7,7 @@
#include <algorithm>
#include <chrono>
#include <cmath>
+#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
@@ -29,44 +30,44 @@ void performance_test() {
conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1),
conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0),
conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t( 1, 272, 272, 47, 125, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 272, 272, 75, 100, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 272, 272, 109, 75, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 544, 544, 24, 63, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 544, 544, 33, 63, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 544, 544, 34, 50, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 544, 544, 36, 63, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 544, 544, 38, 38, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 544, 544, 38, 40, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 544, 544, 47, 38, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 51, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 100, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 248, 248, 93, 250, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 248, 248, 128, 250, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 248, 248, 133, 200, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 248, 248, 150, 150, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 248, 248, 150, 151, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 248, 248, 150, 158, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 248, 248, 188, 150, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 248, 248, 225, 150, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 272, 272, 47, 125, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 51, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 100, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 8, 8, 4, 4, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 272, 272, 64, 125, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 272, 272, 66, 125, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 272, 272, 67, 100, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 272, 272, 75, 75, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 272, 272, 75, 76, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 272, 272, 75, 100, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 272, 272, 94, 75, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 272, 272, 109, 75, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 544, 544, 24, 63, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 544, 544, 33, 63, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 544, 544, 34, 50, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 544, 544, 36, 63, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 544, 544, 38, 38, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 544, 544, 38, 40, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 544, 544, 47, 38, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(51, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(100, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 248, 248, 93, 250, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 248, 248, 128, 250, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 248, 248, 133, 200, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 248, 248, 150, 150, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 248, 248, 150, 151, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 248, 248, 150, 158, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 248, 248, 188, 150, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 248, 248, 225, 150, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 272, 272, 64, 125, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 272, 272, 66, 125, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 272, 272, 67, 100, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 272, 272, 75, 75, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 272, 272, 75, 76, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 272, 272, 94, 75, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(51, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(100, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 8, 8, 4, 4, 1, 3, 3, 1, 1, 1, 1),
};
bool flush = true;
@@ -79,6 +80,49 @@ void performance_test() {
constexpr int NWARMUP = 4;
constexpr int NITER = 10;
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << "WARNING: the timer may be inaccurate when used by multiple threads."
+ << endl;
+ cout << "MB, "
+ << "IC, "
+ << "OC, "
+ << "IH, "
+ << "IW, "
+ << "KH, "
+ << "KW, "
+ << "stride_h, "
+ << "stride_w, "
+ << "pad_h, "
+ << "pad_w, "
+ << "Type, "
+ << "M, "
+ << "N, "
+ << "K, "
+ << "Im2Col (ms), "
+ << "Packing (ms), "
+ << "Kernel (ms), "
+ << "Postprocessing (ms), "
+ << "fbgemmPacked (ms), "
+ << "Total (ms), "
+ << "GOPS" << endl;
+#else
+ cout << setw(8) << "MB, "
+ << "IC, "
+ << "OC, "
+ << "IH, "
+ << "IW, "
+ << "KH, "
+ << "KW, "
+ << "stride_h, "
+ << "stride_w, "
+ << "pad_h, "
+ << "pad_w, "
+ << "Type, "
+ << "M, "
+ << "N, "
+ << "K, " << setw(5) << "GOPS" << endl;
+#endif
+
chrono::time_point<chrono::high_resolution_clock> begin, end;
for (auto conv_p : shapes) {
aligned_vector<float> Afp32(
@@ -96,7 +140,7 @@ void performance_test() {
conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0);
aligned_vector<int32_t> Cint32_ref(
- conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0.0f);
+ conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
aligned_vector<int32_t> Cint32_fb(
conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
@@ -104,8 +148,6 @@ void performance_test() {
aligned_vector<int32_t> Cint32_fb2(
conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
- cout << conv_p.toString() << endl;
-
// A matrix (input activations)
randFill(Afp32, 0, 5);
int32_t Aint8_zero_point = 4;
@@ -137,7 +179,9 @@ void performance_test() {
// "B unpacked");
// packedB.printPackedMatrix("B Packed");
- double ttot = 0;
+ double nops = 2.0 * static_cast<double>(NITER) * MDim * NDim * KDim;
+ double ttot = 0.0;
+ string runType;
vector<int32_t> row_offset_buf;
row_offset_buf.resize(
@@ -153,8 +197,25 @@ void performance_test() {
DoNothing<int32_t, int32_t> doNothing32BitObj;
memCopy<> memcopyObj(doNothing32BitObj);
+ runType = "FusedIm2Col";
ttot = 0;
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ double im2col_time = 0.0;
+ double total_im2col_time = 0.0;
+ double total_packing_time = 0.0;
+ double total_computing_time = 0.0;
+ double total_kernel_time = 0.0;
+ double total_postprocessing_time = 0.0;
+ double total_run_time = 0.0;
+#endif
for (auto i = 0; i < NWARMUP + NITER; ++i) {
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ packing_time = 0.0;
+ computing_time = 0.0;
+ kernel_time = 0.0;
+ postprocessing_time = 0.0;
+ run_time = 0.0;
+#endif
llc_flush(llc);
begin = chrono::high_resolution_clock::now();
fbgemmPacked(
@@ -171,20 +232,69 @@ void performance_test() {
if (i >= NWARMUP) {
auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin);
ttot += dur.count();
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ total_packing_time += packing_time;
+ total_computing_time += computing_time;
+ total_kernel_time += kernel_time;
+ total_postprocessing_time += postprocessing_time;
+ total_run_time += run_time;
+#endif
}
}
- cout << fixed << "fused im2col GOPs: "
- << static_cast<double>(NITER) * 2 * MDim * NDim * KDim / ttot << endl;
+
+ cout << setw(4) << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC
+ << ", " << conv_p.IH << ", " << conv_p.IW << ", " << conv_p.G << ", "
+ << conv_p.KH << ", " << conv_p.KW << ", " << conv_p.stride_h << ", "
+ << conv_p.stride_w << ", " << conv_p.pad_h << ", " << conv_p.pad_w
+ << ", ";
+
+ cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5)
+ << setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6)
+ << KDim << ", ";
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << fixed << setprecision(6) << setw(8) << 0 << ", "
+ << total_packing_time / (double)NITER / 1e6 << ", "
+ << total_kernel_time / (double)NITER / 1e6 << ", "
+ << total_postprocessing_time / (double)NITER / 1e6 << ", "
+ << total_run_time / (double)NITER / 1e6 << ", "
+ << ttot / (double)NITER / 1e6 << ", ";
+#endif
+ cout << setprecision(2) << nops / ttot << endl;
compare_buffers(Cint32_ref.data(), Cint32_fb.data(), MDim, NDim, NDim, 5);
+ runType = "UnfusedIm2Col";
+ row_offset_buf.resize(
+ PackAWithRowOffset<uint8_t, int16_t>::rowOffsetBufferSize());
ttot = 0;
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ total_im2col_time = 0.0;
+ total_packing_time = 0.0;
+ total_computing_time = 0.0;
+ total_kernel_time = 0.0;
+ total_postprocessing_time = 0.0;
+ total_run_time = 0.0;
+#endif
for (auto i = 0; i < NWARMUP + NITER; ++i) {
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ im2col_time = 0.0;
+ packing_time = 0.0;
+ computing_time = 0.0;
+ kernel_time = 0.0;
+ postprocessing_time = 0.0;
+ run_time = 0.0;
+#endif
llc_flush(llc);
begin = chrono::high_resolution_clock::now();
im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_out.data());
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ end = chrono::high_resolution_clock::now();
+ im2col_time =
+ chrono::duration_cast<chrono::nanoseconds>(end - begin).count();
+#endif
+
// printMatrix(matrix_op_t::NoTranspose, Aint8_out.data(), MDim, KDim,
// KDim, "A_out after im2col unpacked");
@@ -213,8 +323,17 @@ void performance_test() {
if (i >= NWARMUP) {
auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin);
ttot += dur.count();
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ total_im2col_time += im2col_time;
+ total_packing_time += packing_time;
+ total_computing_time += computing_time;
+ total_kernel_time += kernel_time;
+ total_postprocessing_time += postprocessing_time;
+ total_run_time += run_time;
+#endif
}
}
+
((volatile char*)(llc.data()));
// packedB.printPackedMatrix("bench B Packed");
@@ -225,9 +344,26 @@ void performance_test() {
// printMatrix(matrix_op_t::NoTranspose,
// Cint32_ref.data(), MDim, NDim, NDim, "C ref fp32");
- cout << fixed << "unfused im2col GOPs: "
- << static_cast<double>(NITER) * 2 * MDim * NDim * KDim / ttot << endl;
- // cout << "total time: " << ttot << " ns" << endl;
+ cout << setw(4) << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC
+ << ", " << conv_p.IH << ", " << conv_p.IW << ", " << conv_p.G << ", "
+ << conv_p.KH << ", " << conv_p.KW << ", " << conv_p.stride_h << ", "
+ << conv_p.stride_w << ", " << conv_p.pad_h << ", " << conv_p.pad_w
+ << ", ";
+
+ cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5)
+ << setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6)
+ << KDim << ", ";
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << fixed << setprecision(6) << setw(8)
+ << total_im2col_time / (double)NITER / 1e6 << ", "
+ << total_packing_time / (double)NITER / 1e6 << ", "
+ << total_kernel_time / (double)NITER / 1e6 << ", "
+ << total_postprocessing_time / (double)NITER / 1e6 << ", "
+ << total_run_time / (double)NITER / 1e6 << ", "
+ << ttot / (double)NITER / 1e6 << ", ";
+#endif
+ cout << setprecision(2) << nops / ttot << endl;
+
compare_buffers(Cint32_ref.data(), Cint32_fb2.data(), MDim, NDim, NDim, 5);
} // shapes
}
diff --git a/bench/Im2ColFusedRequantizeAcc32Benchmark.cc b/bench/Im2ColFusedRequantizeAcc32Benchmark.cc
index 9adea49..b608915 100644
--- a/bench/Im2ColFusedRequantizeAcc32Benchmark.cc
+++ b/bench/Im2ColFusedRequantizeAcc32Benchmark.cc
@@ -7,6 +7,7 @@
#include <algorithm>
#include <chrono>
#include <cmath>
+#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
@@ -29,44 +30,44 @@ void performance_test() {
conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1),
conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0),
conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t( 1, 272, 272, 47, 125, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 272, 272, 75, 100, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 272, 272, 109, 75, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 544, 544, 24, 63, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 544, 544, 33, 63, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 544, 544, 34, 50, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 544, 544, 36, 63, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 544, 544, 38, 38, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 544, 544, 38, 40, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 544, 544, 47, 38, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 51, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 100, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t( 1, 248, 248, 93, 250, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 248, 248, 128, 250, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 248, 248, 133, 200, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 248, 248, 150, 150, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 248, 248, 150, 151, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 248, 248, 150, 158, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 248, 248, 188, 150, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 248, 248, 225, 150, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 272, 272, 47, 125, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 51, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 100, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t( 1, 8, 8, 4, 4, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 272, 272, 64, 125, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 272, 272, 66, 125, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 272, 272, 67, 100, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 272, 272, 75, 75, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 272, 272, 75, 76, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 272, 272, 75, 100, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 272, 272, 94, 75, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 272, 272, 109, 75, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 544, 544, 24, 63, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 544, 544, 33, 63, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 544, 544, 34, 50, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 544, 544, 36, 63, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 544, 544, 38, 38, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 544, 544, 38, 40, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 544, 544, 47, 38, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(51, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(100, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 248, 248, 93, 250, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 248, 248, 128, 250, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 248, 248, 133, 200, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 248, 248, 150, 150, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 248, 248, 150, 151, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 248, 248, 150, 158, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 248, 248, 188, 150, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 248, 248, 225, 150, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 272, 272, 64, 125, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 272, 272, 66, 125, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 272, 272, 67, 100, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 272, 272, 75, 75, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 272, 272, 75, 76, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 272, 272, 94, 75, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(51, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(100, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 8, 8, 4, 4, 1, 3, 3, 1, 1, 1, 1),
};
bool flush = true;
@@ -79,6 +80,49 @@ void performance_test() {
constexpr int NWARMUP = 4;
constexpr int NITER = 10;
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << "WARNING: the timer may be inaccurate when used by multiple threads."
+ << endl;
+ cout << "MB, "
+ << "IC, "
+ << "OC, "
+ << "IH, "
+ << "IW, "
+ << "KH, "
+ << "KW, "
+ << "stride_h, "
+ << "stride_w, "
+ << "pad_h, "
+ << "pad_w, "
+ << "Type, "
+ << "M, "
+ << "N, "
+ << "K, "
+ << "Im2Col (ms), "
+ << "Packing (ms), "
+ << "Kernel (ms), "
+ << "Postprocessing (ms), "
+ << "fbgemmPacked (ms), "
+ << "Total (ms), "
+ << "GOPS" << endl;
+#else
+ cout << setw(8) << "MB, "
+ << "IC, "
+ << "OC, "
+ << "IH, "
+ << "IW, "
+ << "KH, "
+ << "KW, "
+ << "stride_h, "
+ << "stride_w, "
+ << "pad_h, "
+ << "pad_w, "
+ << "Type, "
+ << "M, "
+ << "N, "
+ << "K, " << setw(5) << "GOPS" << endl;
+#endif
+
chrono::time_point<chrono::high_resolution_clock> begin, end;
for (auto conv_p : shapes) {
aligned_vector<float> Afp32(
@@ -96,7 +140,7 @@ void performance_test() {
conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0);
aligned_vector<int32_t> Cint32_ref(
- conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0.0f);
+ conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
aligned_vector<int32_t> Cint32_fb(
conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
@@ -104,7 +148,7 @@ void performance_test() {
aligned_vector<int32_t> Cint32_fb2(
conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
- cout << conv_p.toString() << endl;
+ // cout << conv_p.toString() << endl;
// A matrix (input activations)
randFill(Afp32, 0, 5);
@@ -137,7 +181,9 @@ void performance_test() {
// "B unpacked");
// packedB.printPackedMatrix("B Packed");
- double ttot = 0;
+ double nops = 2.0 * static_cast<double>(NITER) * MDim * NDim * KDim;
+ double ttot = 0.0;
+ string runType;
vector<int32_t> row_offset_buf;
row_offset_buf.resize(
@@ -153,8 +199,25 @@ void performance_test() {
DoNothing<int32_t, int32_t> doNothing32BitObj;
memCopy<> memcopyObj(doNothing32BitObj);
+ runType = "FusedIm2Col";
ttot = 0;
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ double im2col_time = 0.0;
+ double total_im2col_time = 0.0;
+ double total_packing_time = 0.0;
+ double total_computing_time = 0.0;
+ double total_kernel_time = 0.0;
+ double total_postprocessing_time = 0.0;
+ double total_run_time = 0.0;
+#endif
for (auto i = 0; i < NWARMUP + NITER; ++i) {
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ packing_time = 0.0;
+ computing_time = 0.0;
+ kernel_time = 0.0;
+ postprocessing_time = 0.0;
+ run_time = 0.0;
+#endif
llc_flush(llc);
begin = chrono::high_resolution_clock::now();
fbgemmPacked(
@@ -171,20 +234,67 @@ void performance_test() {
if (i >= NWARMUP) {
auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin);
ttot += dur.count();
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ total_packing_time += packing_time;
+ total_computing_time += computing_time;
+ total_kernel_time += kernel_time;
+ total_postprocessing_time += postprocessing_time;
+ total_run_time += run_time;
+#endif
}
}
- cout << fixed << "fused im2col GOPs: "
- << static_cast<double>(NITER) * 2 * MDim * NDim * KDim / ttot << endl;
+
+ cout << setw(4) << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC
+ << ", " << conv_p.IH << ", " << conv_p.IW << ", " << conv_p.G << ", "
+ << conv_p.KH << ", " << conv_p.KW << ", " << conv_p.stride_h << ", "
+ << conv_p.stride_w << ", " << conv_p.pad_h << ", " << conv_p.pad_w
+ << ", ";
+
+ cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5)
+ << setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6)
+ << KDim << ", ";
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << fixed << setprecision(6) << setw(8) << 0 << ", "
+ << total_packing_time / (double)NITER / 1e6 << ", "
+ << total_kernel_time / (double)NITER / 1e6 << ", "
+ << total_postprocessing_time / (double)NITER / 1e6 << ", "
+ << total_run_time / (double)NITER / 1e6 << ", "
+ << ttot / (double)NITER / 1e6 << ", ";
+#endif
+ cout << setprecision(2) << nops / ttot << endl;
compare_buffers(Cint32_ref.data(), Cint32_fb.data(), MDim, NDim, NDim, 5);
+ runType = "UnfusedIm2Col";
ttot = 0;
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ total_im2col_time = 0.0;
+ total_packing_time = 0.0;
+ total_computing_time = 0.0;
+ total_kernel_time = 0.0;
+ total_postprocessing_time = 0.0;
+ total_run_time = 0.0;
+#endif
for (auto i = 0; i < NWARMUP + NITER; ++i) {
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ im2col_time = 0.0;
+ packing_time = 0.0;
+ computing_time = 0.0;
+ kernel_time = 0.0;
+ postprocessing_time = 0.0;
+ run_time = 0.0;
+#endif
llc_flush(llc);
begin = chrono::high_resolution_clock::now();
im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_out.data());
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ end = chrono::high_resolution_clock::now();
+ im2col_time =
+ chrono::duration_cast<chrono::nanoseconds>(end - begin).count();
+#endif
+
// printMatrix(matrix_op_t::NoTranspose, Aint8_out.data(), MDim, KDim,
// KDim, "A_out after im2col unpacked");
@@ -213,6 +323,14 @@ void performance_test() {
if (i >= NWARMUP) {
auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin);
ttot += dur.count();
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ total_im2col_time += im2col_time;
+ total_packing_time += packing_time;
+ total_computing_time += computing_time;
+ total_kernel_time += kernel_time;
+ total_postprocessing_time += postprocessing_time;
+ total_run_time += run_time;
+#endif
}
}
@@ -226,9 +344,26 @@ void performance_test() {
// printMatrix(matrix_op_t::NoTranspose,
// Cint32_ref.data(), MDim, NDim, NDim, "C ref fp32");
- cout << fixed << "unfused im2col GOPs: "
- << static_cast<double>(NITER) * 2 * MDim * NDim * KDim / ttot << endl;
- // cout << "total time: " << ttot << " ns" << endl;
+ cout << setw(4) << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC
+ << ", " << conv_p.IH << ", " << conv_p.IW << ", " << conv_p.G << ", "
+ << conv_p.KH << ", " << conv_p.KW << ", " << conv_p.stride_h << ", "
+ << conv_p.stride_w << ", " << conv_p.pad_h << ", " << conv_p.pad_w
+ << ", ";
+
+ cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5)
+ << setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6)
+ << KDim << ", ";
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << fixed << setprecision(6) << setw(8)
+ << total_im2col_time / (double)NITER / 1e6 << ", "
+ << total_packing_time / (double)NITER / 1e6 << ", "
+ << total_kernel_time / (double)NITER / 1e6 << ", "
+ << total_postprocessing_time / (double)NITER / 1e6 << ", "
+ << total_run_time / (double)NITER / 1e6 << ", "
+ << ttot / (double)NITER / 1e6 << ", ";
+#endif
+ cout << setprecision(2) << nops / ttot << endl;
+
compare_buffers(Cint32_ref.data(), Cint32_fb2.data(), MDim, NDim, NDim, 5);
} // shapes
}
diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc
index 7012289..a007685 100644
--- a/src/PackAWithIm2Col.cc
+++ b/src/PackAWithIm2Col.cc
@@ -10,6 +10,8 @@
#include <iostream>
#include "fbgemm/Fbgemm.h"
+#include <algorithm>
+
namespace fbgemm2 {
template <typename T, typename accT>
@@ -50,6 +52,7 @@ PackAWithIm2Col<T, accT>::PackAWithIm2Col(
aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)));
}
if (row_offset) {
+ rowOffsetAllocatedHere = false;
row_offset_ = row_offset;
} else {
rowOffsetAllocatedHere = true;
@@ -65,39 +68,66 @@ void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) {
block.col_start,
(block.col_size + row_interleave_B_ - 1) /
row_interleave_B_ * row_interleave_B_};
-
BaseType::packedBlock(block_p);
T* out = BaseType::getBuf();
+
for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
int n = i / (conv_p_.OH * conv_p_.OW);
int hw = i % (conv_p_.OH * conv_p_.OW);
int w = hw % conv_p_.OW;
int h = hw / conv_p_.OW;
- for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
- int c = j % conv_p_.IC;
+ for (int j = block.col_start;
+ j < block.col_start + block.col_size + conv_p_.IC - 1;
+ j += conv_p_.IC) {
+ int j_blk_id = j / conv_p_.IC;
+ // max( j_blk_id * IC, START) -> min( END, (j_blk_id + 1) * IC )
+ int j_blk_start = std::max(j_blk_id * conv_p_.IC, block.col_start);
+ int j_blk_end = std::min(
+ (j_blk_id + 1) * conv_p_.IC, block.col_start + block.col_size);
+ if (j_blk_start >= j_blk_end) {
+ break;
+ }
+
int rs = j / conv_p_.IC;
int s = rs % conv_p_.KW;
int r = rs / conv_p_.KW;
int w_in = -conv_p_.pad_w + w * conv_p_.stride_w + s;
int h_in = -conv_p_.pad_h + h * conv_p_.stride_h + r;
- // Please note that padding for convolution should be filled with zero_pt
+
if (h_in < 0 || h_in >= conv_p_.IH || w_in < 0 || w_in >= conv_p_.IW) {
- out[(i - block.row_start) * BaseType::blockColSize() +
- (j - block.col_start)] = BaseType::zeroPoint();
+ // Please note that padding for convolution should be filled with
+ // zero_pt
+ std::memset(
+ &out
+ [(i - block.row_start) * BaseType::blockColSize() +
+ (j_blk_start - block.col_start)],
+ BaseType::zeroPoint(),
+ sizeof(T) * (j_blk_end - j_blk_start));
} else {
- out[(i - block.row_start) * BaseType::blockColSize() +
- (j - block.col_start)] = sdata_
- [((n * conv_p_.IH + h_in) * conv_p_.IW + w_in) * conv_p_.IC + c];
+ std::memcpy(
+ &out
+ [(i - block.row_start) * BaseType::blockColSize() +
+ j_blk_start - block.col_start],
+ &sdata_
+ [((n * conv_p_.IH + h_in) * conv_p_.IW + w_in) * conv_p_.IC +
+ (j_blk_start % conv_p_.IC)],
+ sizeof(T) * (j_blk_end - j_blk_start));
}
}
// zero fill
// Please see the comment in PackAMatrix.cc for zero vs zero_pt fill.
- for (int j = block.col_start + block.col_size;
- j < block_p.col_start + block_p.col_size;
- ++j) {
- out[(i - block.row_start) * BaseType::blockColSize() +
- (j - block.col_start)] = 0;
+ if ((block_p.col_start + block_p.col_size) -
+ (block.col_start + block.col_size) >
+ 0) {
+ std::memset(
+ &out
+ [(i - block.row_start) * BaseType::blockColSize() +
+ (block.col_size)],
+ 0,
+ sizeof(T) *
+ ((block_p.col_start + block_p.col_size) -
+ (block.col_start + block.col_size)));
}
}
}
@@ -111,7 +141,7 @@ void PackAWithIm2Col<T, accT>::printPackedMatrix(std::string name) {
T* out = BaseType::getBuf();
for (auto r = 0; r < BaseType::numPackedRows(); ++r) {
for (auto c = 0; c < BaseType::numPackedCols(); ++c) {
- T val = out[ r * BaseType::blockColSize() + c ];
+ T val = out[r * BaseType::blockColSize() + c];
if (std::is_integral<T>::value) {
// cast to int64 because cout doesn't print int8_t type directly
std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
index 9aedc88..10e581f 100644
--- a/src/RefImplementations.cc
+++ b/src/RefImplementations.cc
@@ -8,6 +8,7 @@
#include <cassert>
#include <cmath>
+#include <cstring>
using namespace std;
@@ -226,6 +227,12 @@ int32_t clip_16bit(int32_t x) {
}
}
+/* Imitate the Im2Col<float, CPUContext, StorageOrder::NHWC> function
+ * from caffe2/utils/math_cpu.cc
+ * NHWC StorageOrder/Layout
+ * A: NHWC: NH_0W_0 x C_0
+ * Ao: NHWC: NH_1W_1 x RSC_0
+ */
void im2col_ref(
const conv_param_t& conv_p,
const std::uint8_t* A,
@@ -238,20 +245,27 @@ void im2col_ref(
int h_in = -conv_p.pad_h + h * conv_p.stride_h + r;
for (int s = 0; s < conv_p.KW; ++s) {
int w_in = -conv_p.pad_w + w * conv_p.stride_w + s;
- for (int c = 0; c < conv_p.IC; ++c) {
- // Ai: NHWC: NH_0W_0 x C_0
- std::uint8_t val =
- h_in < 0 || h_in >= conv_p.IH || w_in < 0 || w_in >= conv_p.IW
- ? A_zero_point
- : A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) * conv_p.IC +
- c];
- // Ao: NHWC: NH_1W_1 x RSC_0
- Ao[((((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.KH + r) *
- conv_p.KW +
- s) *
- conv_p.IC +
- c] = val;
- } // for each c
+ if (h_in < 0 || h_in >= conv_p.IH || w_in < 0 ||
+ w_in >= conv_p.IW) {
+ std::memset(
+ &Ao[((((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.KH + r) *
+ conv_p.KW +
+ s) *
+ conv_p.IC +
+ 0],
+ A_zero_point,
+ sizeof(uint8_t) * conv_p.IC);
+ } else {
+ std::memcpy(
+ &Ao[((((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.KH + r) *
+ conv_p.KW +
+ s) *
+ conv_p.IC +
+ 0],
+ &A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) * conv_p.IC +
+ 0],
+ sizeof(uint8_t) * conv_p.IC);
+ }
} // for each s
} // for each r
} // for each w
diff --git a/test/I8SpmdmTest.cc b/test/I8SpmdmTest.cc
index c74c98a..cd8a94c 100644
--- a/test/I8SpmdmTest.cc
+++ b/test/I8SpmdmTest.cc
@@ -124,7 +124,9 @@ TEST_P(fbgemmSPMDMTest, TestsSpMDM) {
}
}
+#ifdef _OPENMP
#pragma omp parallel
+#endif
{
#ifdef _OPENMP
int num_threads = omp_get_num_threads();
diff --git a/test/Im2ColFusedRequantizeTest.cc b/test/Im2ColFusedRequantizeTest.cc
new file mode 100644
index 0000000..c09c770
--- /dev/null
+++ b/test/Im2ColFusedRequantizeTest.cc
@@ -0,0 +1,210 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <cmath>
+#include <cstdio>
+
+#include <gtest/gtest.h>
+
+#include "bench/AlignedVec.h"
+#include "bench/BenchUtils.h"
+#include "fbgemm/Fbgemm.h"
+#include "src/RefImplementations.h"
+#include "TestUtils.h"
+
+using namespace std;
+
+namespace fbgemm2 {
+
+// From Xray OCR
+static vector<conv_param_t> shapes = {
+ // MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w
+ conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0),
+ conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0),
+ conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 1, 1, 1, 1),
+ // conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 1, 1, 1, 1 ),
+ // conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 1, 1, 1, 1 ),
+ // conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 1, 1, 1, 1 ),
+ // conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 1, 1, 1, 1 ),
+ // conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 1, 1, 1, 1 ),
+ // conv_param_t( 1, 272, 272, 75, 100, 1, 3, 3, 1, 1, 1, 1 ),
+ // conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 1, 1, 1, 1 ),
+ // conv_param_t(1, 272, 272, 109, 75, 1, 3, 3, 1, 1, 1, 1),
+ // conv_param_t(1, 544, 544, 24, 63, 1, 3, 3, 1, 1, 1, 1),
+ // conv_param_t( 1, 544, 544, 33, 63, 1, 3, 3, 1, 1, 1, 1 ),
+ // conv_param_t( 1, 544, 544, 34, 50, 1, 3, 3, 1, 1, 1, 1 ),
+ // conv_param_t( 1, 544, 544, 36, 63, 1, 3, 3, 1, 1, 1, 1 ),
+ // conv_param_t( 1, 544, 544, 38, 38, 1, 3, 3, 1, 1, 1, 1 ),
+ // conv_param_t( 1, 544, 544, 38, 40, 1, 3, 3, 1, 1, 1, 1 ),
+ // conv_param_t( 1, 544, 544, 47, 38, 1, 3, 3, 1, 1, 1, 1 ),
+ // conv_param_t( 1, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ),
+ // conv_param_t(51, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1),
+ // conv_param_t( 100, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ),
+ conv_param_t(1, 248, 248, 93, 250, 1, 3, 3, 2, 2, 1, 1),
+ // conv_param_t( 1, 248, 248, 128, 250, 1, 3, 3, 2, 2, 1, 1 ),
+ // conv_param_t( 1, 248, 248, 133, 200, 1, 3, 3, 2, 2, 1, 1 ),
+ // conv_param_t( 1, 248, 248, 150, 150, 1, 3, 3, 2, 2, 1, 1 ),
+ // conv_param_t( 1, 248, 248, 150, 151, 1, 3, 3, 2, 2, 1, 1 ),
+ // conv_param_t( 1, 248, 248, 150, 158, 1, 3, 3, 2, 2, 1, 1 ),
+ // conv_param_t( 1, 248, 248, 188, 150, 1, 3, 3, 2, 2, 1, 1 ),
+ // conv_param_t(1, 248, 248, 225, 150, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 2, 2, 1, 1),
+ // conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 2, 2, 1, 1 ),
+ // conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 2, 2, 1, 1 ),
+ // conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 2, 2, 1, 1 ),
+ // conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 2, 2, 1, 1 ),
+ // conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 2, 2, 1, 1 ),
+ // conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 2, 2, 1, 1 ),
+ // conv_param_t( 1, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ),
+ // conv_param_t(51, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1),
+ conv_param_t(3, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1),
+ // conv_param_t( 100, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ),
+ conv_param_t(1, 8, 8, 4, 4, 1, 3, 3, 1, 1, 1, 1),
+};
+
+TEST(FBGemmIm2colTest, Acc32Test) {
+ for (auto conv_p : shapes) {
+ aligned_vector<uint8_t> Aint8(
+ conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0);
+ aligned_vector<int8_t> Bint8(
+ conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0);
+ aligned_vector<int32_t> Cint32_ref(
+ conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0.0f);
+ aligned_vector<int32_t> Cint32_fb(
+ conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
+
+ randFill(Aint8, 0, 80);
+ int32_t Aint8_zero_point = 43;
+ randFill(Bint8, -16, 16);
+
+ conv_ref(
+ conv_p,
+ Aint8.data(),
+ Aint8_zero_point,
+ Bint8.data(),
+ Cint32_ref.data());
+
+ int NDim = conv_p.OC;
+ int KDim = conv_p.KH * conv_p.KW * conv_p.IC;
+
+ vector<int32_t> row_offset_buf;
+ row_offset_buf.resize(
+ PackAWithIm2Col<uint8_t, int32_t>::rowOffsetBufferSize());
+
+ PackAWithIm2Col<uint8_t, int32_t> packA(
+ conv_p, Aint8.data(), nullptr, Aint8_zero_point, row_offset_buf.data());
+
+ PackBMatrix<int8_t, int32_t> packedB(
+ matrix_op_t::NoTranspose, KDim, NDim, Bint8.data(), NDim);
+
+ // no-op output process objects
+ DoNothing<int32_t, int32_t> doNothing32BitObj;
+ memCopy<> memcopyObj(doNothing32BitObj);
+
+ fbgemmPacked(
+ packA,
+ packedB,
+ Cint32_fb.data(),
+ Cint32_fb.data(),
+ NDim,
+ memcopyObj,
+ 0,
+ 1);
+
+ // correctness check
+ for (int n = 0; n < conv_p.MB; ++n) {
+ for (int h = 0; h < conv_p.OH; ++h) {
+ for (int w = 0; w < conv_p.OW; ++w) {
+ for (int k = 0; k < conv_p.OC; ++k) {
+ int32_t expected = Cint32_ref
+ [((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k];
+ int32_t actual = Cint32_fb
+ [((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k];
+ EXPECT_EQ(expected, actual)
+ << "Im2Col fused results differ at (" << n << ", " << h << ", "
+ << w << ", " << k << ").";
+ }
+ }
+ }
+ }
+
+ } // for each shape
+} // Acc32Test
+
+
+TEST(FBGemmIm2colTest, Acc16Test) {
+ for (auto conv_p : shapes) {
+ aligned_vector<uint8_t> Aint8(
+ conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0);
+ aligned_vector<int8_t> Bint8(
+ conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0);
+ aligned_vector<int32_t> Cint32_ref(
+ conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0.0f);
+ aligned_vector<int32_t> Cint32_fb(
+ conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
+
+ randFill(Aint8, 0, 5);
+ int32_t Aint8_zero_point = 4;
+ randFill(Bint8, -4, 4);
+
+ conv_ref(
+ conv_p,
+ Aint8.data(),
+ Aint8_zero_point,
+ Bint8.data(),
+ Cint32_ref.data());
+
+ int NDim = conv_p.OC;
+ int KDim = conv_p.KH * conv_p.KW * conv_p.IC;
+
+ vector<int32_t> row_offset_buf;
+ row_offset_buf.resize(
+ PackAWithIm2Col<uint8_t, int16_t>::rowOffsetBufferSize());
+
+ PackAWithIm2Col<uint8_t, int16_t> packA(
+ conv_p, Aint8.data(), nullptr, Aint8_zero_point, row_offset_buf.data());
+
+ PackBMatrix<int8_t, int16_t> packedB(
+ matrix_op_t::NoTranspose, KDim, NDim, Bint8.data(), NDim);
+
+ // no-op output process objects
+ DoNothing<int32_t, int32_t> doNothing32BitObj;
+ memCopy<> memcopyObj(doNothing32BitObj);
+
+ fbgemmPacked(
+ packA,
+ packedB,
+ Cint32_fb.data(),
+ Cint32_fb.data(),
+ NDim,
+ memcopyObj,
+ 0,
+ 1);
+
+ // correctness check
+ for (int n = 0; n < conv_p.MB; ++n) {
+ for (int h = 0; h < conv_p.OH; ++h) {
+ for (int w = 0; w < conv_p.OW; ++w) {
+ for (int k = 0; k < conv_p.OC; ++k) {
+ int32_t expected = Cint32_ref
+ [((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k];
+ int32_t actual = Cint32_fb
+ [((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k];
+ EXPECT_EQ(expected, actual)
+ << "Im2Col fused results differ at (" << n << ", " << h << ", "
+ << w << ", " << k << ").";
+ }
+ }
+ }
+ }
+
+ } // for each shape
+} // Acc16Test
+
+
+} // namespace fbgemm2