Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordskhudia <dskhudia@fb.com>2018-11-06 00:17:52 +0300
committerdskhudia <dskhudia@fb.com>2018-11-06 00:17:52 +0300
commitb96bc0bf311f7abdc83ffd3af0a485b4aef53f7c (patch)
tree2a6c276d20753abe94c526aab7b109305e3d1d78
parent14adee1ac506e067489406af689ae9b73fb581bd (diff)
generalized conv_param_t and download third party libraries in build dir
-rw-r--r--CMakeLists.txt2
-rw-r--r--README.md4
-rw-r--r--bench/Im2ColFusedRequantizeAcc16Benchmark.cc123
-rw-r--r--bench/Im2ColFusedRequantizeAcc32Benchmark.cc123
-rw-r--r--include/fbgemm/ConvUtils.h113
-rw-r--r--include/fbgemm/Fbgemm.h16
-rw-r--r--src/ExecuteKernelU8S8.cc12
-rw-r--r--src/Fbgemm.cc20
-rw-r--r--src/PackAWithIm2Col.cc143
-rw-r--r--src/PackMatrix.cc8
-rw-r--r--src/RefImplementations.cc214
-rw-r--r--src/RefImplementations.h22
-rw-r--r--test/Im2ColFusedRequantizeTest.cc346
13 files changed, 834 insertions, 312 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ad8ffd9..d957bed 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -18,8 +18,8 @@ if(FBGEMM_BUILD_TESTS)
endif()
set(FBGEMM_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
-set(FBGEMM_THIRDPARTY_DIR ${FBGEMM_SOURCE_DIR}/third-party)
set(FBGEMM_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
+set(FBGEMM_THIRDPARTY_DIR ${FBGEMM_BINARY_DIR}/third_party)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
#All the source files that either use avx2 instructions statically or JIT
diff --git a/README.md b/README.md
index 2b7efaa..224c751 100644
--- a/README.md
+++ b/README.md
@@ -44,8 +44,8 @@ is **on**. Turn it off by setting FBGEMM\_BUILD\_TESTS to off.
You can download [asmjit][1], [cpuinfo][2], [googletest][3] and set
ASMJIT\_SRC\_DIR, CPUINFO\_SRC\_DIR, GOOGLETEST\_SOURCE\_DIR respectively for
cmake to find these libraries. If any of these variables is not set, cmake will
-try to download that missing library in a folder called third-party in the
-current directory and build it using the downloaded source code.
+try to download that missing library in a folder called third\_party in the
+build directory and build it using the downloaded source code.
FBGEMM, in general, does not have any dependency on Intel MKL. However, for
performance comparison, some benchmarks use MKL functions. If MKL is found or
diff --git a/bench/Im2ColFusedRequantizeAcc16Benchmark.cc b/bench/Im2ColFusedRequantizeAcc16Benchmark.cc
index c24f6fa..ca27278 100644
--- a/bench/Im2ColFusedRequantizeAcc16Benchmark.cc
+++ b/bench/Im2ColFusedRequantizeAcc16Benchmark.cc
@@ -24,50 +24,50 @@ using namespace std;
using namespace fbgemm2;
void performance_test() {
- vector<conv_param_t> shapes = {
+ vector<conv_param_t<>> shapes = {
// MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w
- conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0),
- conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0),
- conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 64, 125, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 66, 125, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 67, 100, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 75, 75, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 75, 76, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 75, 100, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 94, 75, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 109, 75, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 544, 544, 24, 63, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 544, 544, 33, 63, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 544, 544, 34, 50, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 544, 544, 36, 63, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 544, 544, 38, 38, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 544, 544, 38, 40, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 544, 544, 47, 38, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(51, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(100, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 248, 248, 93, 250, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 248, 248, 128, 250, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 248, 248, 133, 200, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 248, 248, 150, 150, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 248, 248, 150, 151, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 248, 248, 150, 158, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 248, 248, 188, 150, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 248, 248, 225, 150, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 272, 272, 64, 125, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 272, 272, 66, 125, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 272, 272, 67, 100, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 272, 272, 75, 75, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 272, 272, 75, 76, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 272, 272, 94, 75, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(51, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(100, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 8, 8, 4, 4, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t<>(1, 32, 32, {14, 14}, 1, {3, 3}, {2, 2}, {0, 0}),
+ conv_param_t<>(1, 32, 32, {14, 14}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(2, 32, 32, {14, 14}, 1, {3, 3}, {2, 2}, {0, 0}),
+ conv_param_t<>(2, 32, 32, {14, 14}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {47, 125}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {64, 125}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {66, 125}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {67, 100}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {75, 75}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {75, 76}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {75, 100}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {94, 75}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {109, 75}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 544, 544, {24, 63}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 544, 544, {33, 63}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 544, 544, {34, 50}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 544, 544, {36, 63}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 544, 544, {38, 38}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 544, 544, {38, 40}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 544, 544, {47, 38}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 1088, 1088, {7, 7}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(51, 1088, 1088, {7, 7}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(100, 1088, 1088, {7, 7}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 248, 248, {93, 250}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 248, 248, {128, 250}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 248, 248, {133, 200}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 248, 248, {150, 150}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 248, 248, {150, 151}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 248, 248, {150, 158}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 248, 248, {188, 150}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 248, 248, {225, 150}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {47, 125}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {64, 125}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {66, 125}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {67, 100}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {75, 75}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {75, 76}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {94, 75}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 544, 544, {14, 14}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(51, 544, 544, {14, 14}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(100, 544, 544, {14, 14}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 8, 8, {4, 4}, 1, {3, 3}, {1, 1}, {1, 1}),
};
bool flush = true;
@@ -126,27 +126,28 @@ void performance_test() {
chrono::time_point<chrono::high_resolution_clock> begin, end;
for (auto conv_p : shapes) {
aligned_vector<float> Afp32(
- conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0.0f);
+ conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 0.0f);
aligned_vector<uint8_t> Aint8(
- conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0);
+ conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 0);
aligned_vector<uint8_t> Aint8_out(
- conv_p.MB * conv_p.OH * conv_p.OW * conv_p.KH * conv_p.KW * conv_p.IC,
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.K[0] *
+ conv_p.K[1] * conv_p.IC,
0);
aligned_vector<float> Bfp32(
- conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0.0f);
+ conv_p.K[0] * conv_p.K[1] * conv_p.IC * conv_p.OC, 0.0f);
aligned_vector<int8_t> Bint8(
- conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0);
+ conv_p.K[0] * conv_p.K[1] * conv_p.IC * conv_p.OC, 0);
aligned_vector<int32_t> Cint32_ref(
- conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0);
aligned_vector<int32_t> Cint32_fb(
- conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0);
aligned_vector<int32_t> Cint32_fb2(
- conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0);
// A matrix (input activations)
randFill(Afp32, 0, 5);
@@ -171,9 +172,9 @@ void performance_test() {
Cint32_ref.data());
// matrix dimensions after im2col
- int MDim = conv_p.MB * conv_p.OH * conv_p.OW;
+ int MDim = conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1];
int NDim = conv_p.OC;
- int KDim = conv_p.KH * conv_p.KW * conv_p.IC;
+ int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC;
// printMatrix(matrix_op_t::NoTranspose, Bint8.data(), KDim, NDim, NDim,
// "B unpacked");
@@ -243,10 +244,10 @@ void performance_test() {
}
cout << setw(4) << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC
- << ", " << conv_p.IH << ", " << conv_p.IW << ", " << conv_p.G << ", "
- << conv_p.KH << ", " << conv_p.KW << ", " << conv_p.stride_h << ", "
- << conv_p.stride_w << ", " << conv_p.pad_h << ", " << conv_p.pad_w
- << ", ";
+ << ", " << conv_p.IN_DIM[0] << ", " << conv_p.IN_DIM[1] << ", "
+ << conv_p.G << ", " << conv_p.K[0] << ", " << conv_p.K[1] << ", "
+ << conv_p.stride[0] << ", " << conv_p.stride[1] << ", "
+ << conv_p.pad[0] << ", " << conv_p.pad[1] << ", ";
cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5)
<< setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6)
@@ -345,10 +346,10 @@ void performance_test() {
// Cint32_ref.data(), MDim, NDim, NDim, "C ref fp32");
cout << setw(4) << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC
- << ", " << conv_p.IH << ", " << conv_p.IW << ", " << conv_p.G << ", "
- << conv_p.KH << ", " << conv_p.KW << ", " << conv_p.stride_h << ", "
- << conv_p.stride_w << ", " << conv_p.pad_h << ", " << conv_p.pad_w
- << ", ";
+ << ", " << conv_p.IN_DIM[0] << ", " << conv_p.IN_DIM[1] << ", "
+ << conv_p.G << ", " << conv_p.K[0] << ", " << conv_p.K[1] << ", "
+ << conv_p.stride[0] << ", " << conv_p.stride[1] << ", "
+ << conv_p.pad[0] << ", " << conv_p.pad[1] << ", ";
cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5)
<< setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6)
diff --git a/bench/Im2ColFusedRequantizeAcc32Benchmark.cc b/bench/Im2ColFusedRequantizeAcc32Benchmark.cc
index b608915..8cce235 100644
--- a/bench/Im2ColFusedRequantizeAcc32Benchmark.cc
+++ b/bench/Im2ColFusedRequantizeAcc32Benchmark.cc
@@ -24,50 +24,50 @@ using namespace std;
using namespace fbgemm2;
void performance_test() {
- vector<conv_param_t> shapes = {
+ vector<conv_param_t<>> shapes = {
// MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w
- conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0),
- conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0),
- conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 64, 125, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 66, 125, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 67, 100, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 75, 75, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 75, 76, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 75, 100, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 94, 75, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 109, 75, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 544, 544, 24, 63, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 544, 544, 33, 63, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 544, 544, 34, 50, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 544, 544, 36, 63, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 544, 544, 38, 38, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 544, 544, 38, 40, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 544, 544, 47, 38, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(51, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(100, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 248, 248, 93, 250, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 248, 248, 128, 250, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 248, 248, 133, 200, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 248, 248, 150, 150, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 248, 248, 150, 151, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 248, 248, 150, 158, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 248, 248, 188, 150, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 248, 248, 225, 150, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 272, 272, 64, 125, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 272, 272, 66, 125, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 272, 272, 67, 100, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 272, 272, 75, 75, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 272, 272, 75, 76, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 272, 272, 94, 75, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(51, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(100, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 8, 8, 4, 4, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t<>(1, 32, 32, {14, 14}, 1, {3, 3}, {2, 2}, {0, 0}),
+ conv_param_t<>(1, 32, 32, {14, 14}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(2, 32, 32, {14, 14}, 1, {3, 3}, {2, 2}, {0, 0}),
+ conv_param_t<>(2, 32, 32, {14, 14}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {47, 125}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {64, 125}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {66, 125}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {67, 100}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {75, 75}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {75, 76}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {75, 100}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {94, 75}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {109, 75}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 544, 544, {24, 63}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 544, 544, {33, 63}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 544, 544, {34, 50}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 544, 544, {36, 63}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 544, 544, {38, 38}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 544, 544, {38, 40}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 544, 544, {47, 38}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 1088, 1088, {7, 7}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(51, 1088, 1088, {7, 7}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(100, 1088, 1088, {7, 7}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(1, 248, 248, {93, 250}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 248, 248, {128, 250}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 248, 248, {133, 200}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 248, 248, {150, 150}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 248, 248, {150, 151}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 248, 248, {150, 158}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 248, 248, {188, 150}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 248, 248, {225, 150}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {47, 125}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {64, 125}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {66, 125}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {67, 100}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {75, 75}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {75, 76}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 272, 272, {94, 75}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 544, 544, {14, 14}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(51, 544, 544, {14, 14}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(100, 544, 544, {14, 14}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 8, 8, {4, 4}, 1, {3, 3}, {1, 1}, {1, 1}),
};
bool flush = true;
@@ -126,27 +126,28 @@ void performance_test() {
chrono::time_point<chrono::high_resolution_clock> begin, end;
for (auto conv_p : shapes) {
aligned_vector<float> Afp32(
- conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0.0f);
+ conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 0.0f);
aligned_vector<uint8_t> Aint8(
- conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0);
+ conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 0);
aligned_vector<uint8_t> Aint8_out(
- conv_p.MB * conv_p.OH * conv_p.OW * conv_p.KH * conv_p.KW * conv_p.IC,
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.K[0] *
+ conv_p.K[1] * conv_p.IC,
0);
aligned_vector<float> Bfp32(
- conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0.0f);
+ conv_p.K[0] * conv_p.K[1] * conv_p.IC * conv_p.OC, 0.0f);
aligned_vector<int8_t> Bint8(
- conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0);
+ conv_p.K[0] * conv_p.K[1] * conv_p.IC * conv_p.OC, 0);
aligned_vector<int32_t> Cint32_ref(
- conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0);
aligned_vector<int32_t> Cint32_fb(
- conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0);
aligned_vector<int32_t> Cint32_fb2(
- conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0);
// cout << conv_p.toString() << endl;
@@ -173,9 +174,9 @@ void performance_test() {
Cint32_ref.data());
// matrix dimensions after im2col
- int MDim = conv_p.MB * conv_p.OH * conv_p.OW;
+ int MDim = conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1];
int NDim = conv_p.OC;
- int KDim = conv_p.KH * conv_p.KW * conv_p.IC;
+ int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC;
// printMatrix(matrix_op_t::NoTranspose, Bint8.data(), KDim, NDim, NDim,
// "B unpacked");
@@ -245,10 +246,10 @@ void performance_test() {
}
cout << setw(4) << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC
- << ", " << conv_p.IH << ", " << conv_p.IW << ", " << conv_p.G << ", "
- << conv_p.KH << ", " << conv_p.KW << ", " << conv_p.stride_h << ", "
- << conv_p.stride_w << ", " << conv_p.pad_h << ", " << conv_p.pad_w
- << ", ";
+ << ", " << conv_p.IN_DIM[0] << ", " << conv_p.IN_DIM[1] << ", "
+ << conv_p.G << ", " << conv_p.K[0] << ", " << conv_p.K[1] << ", "
+ << conv_p.stride[0] << ", " << conv_p.stride[1] << ", "
+ << conv_p.pad[0] << ", " << conv_p.pad[1] << ", ";
cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5)
<< setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6)
@@ -345,10 +346,10 @@ void performance_test() {
// Cint32_ref.data(), MDim, NDim, NDim, "C ref fp32");
cout << setw(4) << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC
- << ", " << conv_p.IH << ", " << conv_p.IW << ", " << conv_p.G << ", "
- << conv_p.KH << ", " << conv_p.KW << ", " << conv_p.stride_h << ", "
- << conv_p.stride_w << ", " << conv_p.pad_h << ", " << conv_p.pad_w
- << ", ";
+ << ", " << conv_p.IN_DIM[0] << ", " << conv_p.IN_DIM[1] << ", "
+ << conv_p.G << ", " << conv_p.K[0] << ", " << conv_p.K[1] << ", "
+ << conv_p.stride[0] << ", " << conv_p.stride[1] << ", "
+ << conv_p.pad[0] << ", " << conv_p.pad[1] << ", ";
cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5)
<< setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6)
diff --git a/include/fbgemm/ConvUtils.h b/include/fbgemm/ConvUtils.h
index 02e862f..438807f 100644
--- a/include/fbgemm/ConvUtils.h
+++ b/include/fbgemm/ConvUtils.h
@@ -5,6 +5,8 @@
* LICENSE file in the root directory of this source tree.
*/
#pragma once
+
+#include <array>
#include <string>
namespace fbgemm2 {
@@ -12,27 +14,21 @@ namespace fbgemm2 {
/**
* @brief A struct to conveniently store all convolution parameters.
*/
+template <int SPATIAL_DIM = 2>
struct conv_param_t {
int MB; ///< Mini Batch size
int IC; ///< Number of Input Channels
int OC; ///< Number of Output Channels
- int IH; ///< Input Image Height
- int IW; ///< Input Image Width
+ std::array<int, SPATIAL_DIM> IN_DIM; ///< Input Image Dimension
int G; ///< Number of Groups
- int KH; ///< Filter (Kernel) Height
- int KW; ///< Filter (Kernel) Width
- int stride_h; ///< Stride in Height Dimension
- int stride_w; ///< Stride in Width Dimension
- int pad_h; ///< Padding in Height Dimension (top and bottom)
- int pad_w; ///< Padding in Width Dimension (left and right)
- int dilation_h; ///< Kernel dilation in Height Dimension
- int dilation_w; ///< Kernel dilation in Width Dimension
+ std::array<int, SPATIAL_DIM> K; ///< Filter (Kernel) dimensions
+ std::array<int, SPATIAL_DIM> stride; //< Strides
+ std::array<int, SPATIAL_DIM> pad; //< Padding (assume symmetric padding)
+ std::array<int, SPATIAL_DIM> dilation; //< Kernel dilation
// The following are derived parameters
- int OH; ///< Output Image Height
- int OW; ///< Output Image Width
- int IHP; ///< Input Height Padded
- int IWP; ///< Input Width Padded
+ std::array<int, SPATIAL_DIM> OUT_DIM; //< Output Image Dimension
+ std::array<int, SPATIAL_DIM> IN_DIMP; //< Input Image Dimension Padded
/**
* @brief Constructor for initializing the convolution parameters.
@@ -42,52 +38,79 @@ struct conv_param_t {
int mb,
int ic,
int oc,
- int ih,
- int iw,
- int g = 1,
- int kh = 3,
- int kw = 3,
- int strd_h = 1,
- int strd_w = 1,
- int pd_h = 1,
- int pd_w = 1)
+ std::array<int, SPATIAL_DIM> in_dim,
+ int g,
+ std::array<int, SPATIAL_DIM> k,
+ std::array<int, SPATIAL_DIM> strd,
+ std::array<int, SPATIAL_DIM> pd)
: MB(mb),
IC(ic),
OC(oc),
- IH(ih),
- IW(iw),
+ IN_DIM(in_dim),
G(g),
- KH(kh),
- KW(kw),
- stride_h(strd_h),
- stride_w(strd_w),
- pad_h(pd_h),
- pad_w(pd_w),
- dilation_h(1),
- dilation_w(1) {
- IHP = IH + 2 * pad_h;
- IWP = IW + 2 * pad_w;
- OH = (IHP - KH) / stride_h + 1;
- OW = (IWP - KW) / stride_w + 1;
+ K(k),
+ stride(strd),
+ pad(pd) {
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
+ dilation[d] = 1;
+ IN_DIMP[d] = IN_DIM[d] + 2 * pad[d];
+ OUT_DIM[d] = (IN_DIMP[d] - K[d]) / stride[d] + 1;
+ }
}
/**
* @brief Helper function to get convolution parameters as string.
*/
std::string toString() const {
+ std::string dim_string[3] = {"T", "H", "W"};
+
std::string out = "";
out += "MB:" + std::to_string(MB) + ", ";
out += "IC:" + std::to_string(IC) + ", ";
out += "OC:" + std::to_string(OC) + ", ";
- out += "IH:" + std::to_string(IH) + ", ";
- out += "IW:" + std::to_string(IW) + ", ";
+ if (SPATIAL_DIM <= 3) {
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
+ out += "I" + dim_string[3 - SPATIAL_DIM + d] + ":" +
+ std::to_string(IN_DIM[d]) + ", ";
+ }
+ } else {
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
+ out += "I" + std::to_string(d) + ":" +
+ std::to_string(IN_DIM[d]) + ", ";
+ }
+ }
out += "G:" + std::to_string(G) + ", ";
- out += "KH:" + std::to_string(KH) + ", ";
- out += "KW:" + std::to_string(KW) + ", ";
- out += "stride_h:" + std::to_string(stride_h) + ", ";
- out += "stride_w:" + std::to_string(stride_w) + ", ";
- out += "pad_h:" + std::to_string(pad_h) + ", ";
- out += "pad_w:" + std::to_string(pad_w);
+ if (SPATIAL_DIM <= 3) {
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
+ out += "K" + dim_string[3 - SPATIAL_DIM + d] + ":" +
+ std::to_string(K[d]) + ", ";
+ }
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
+ out += "stride_" + dim_string[3 - SPATIAL_DIM + d] + ":" +
+ std::to_string(stride[d]) + ", ";
+ }
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
+ out += "pad_" + dim_string[3 - SPATIAL_DIM + d] + ":" +
+ std::to_string(pad[d]);
+ if (d < SPATIAL_DIM - 1) {
+ out += ", ";
+ }
+ }
+ } else {
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
+ out += "K" + std::to_string(d) + ":" + std::to_string(K[d]) + ", ";
+ }
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
+ out += "stride_" + std::to_string(d) + ":" + std::to_string(stride[d]) +
+ ", ";
+ }
+ for (int d = 0; d < SPATIAL_DIM; ++d) {
+ out += "pad_" + std::to_string(d) + ":" + std::to_string(pad[d]);
+ if (d < SPATIAL_DIM - 1) {
+ out += ", ";
+ }
+ }
+ }
return out;
}
};
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index 2f9ddc7..23412be 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -438,11 +438,11 @@ class PackBMatrix final : public PackMatrix<PackBMatrix<T, accT>, T, accT> {
* Im2col is fused with packing here. The source matrix is already
* quantized.
*/
-template <typename T, typename accT = std::int32_t>
-class PackAWithIm2Col final
- : public PackMatrix<PackAWithIm2Col<T, accT>, T, accT> {
+template <typename T, typename accT = std::int32_t, int SPATIAL_DIM = 2>
+class PackAWithIm2Col
+ : public PackMatrix<PackAWithIm2Col<T, accT, SPATIAL_DIM>, T, accT> {
public:
- using This = PackAWithIm2Col<T, accT>;
+ using This = PackAWithIm2Col<T, accT, SPATIAL_DIM>;
using BaseType = PackMatrix<This, T, accT>;
using inpType = T;
using accType = accT;
@@ -452,7 +452,7 @@ class PackAWithIm2Col final
* TODO: Currently only groups == 1 supported
*/
PackAWithIm2Col(
- const conv_param_t& conv_param,
+ const conv_param_t<SPATIAL_DIM>& conv_param,
const T* sdata,
inpType* pmat = nullptr,
std::int32_t zero_pt = 0,
@@ -487,7 +487,7 @@ class PackAWithIm2Col final
}
private:
- const conv_param_t& conv_p_;
+ const conv_param_t<SPATIAL_DIM>& conv_p_;
const T* sdata_;
std::int32_t* row_offset_;
bool rowOffsetAllocatedHere;
@@ -942,8 +942,8 @@ template <
typename outT,
typename processOutputType>
void convDepthwiseSeparable(
- const conv_param_t& conv_param_dw,
- const conv_param_t& conv_param_1x1,
+ const conv_param_t<>& conv_param_dw,
+ const conv_param_t<>& conv_param_1x1,
packingAMatrix& packdw,
packingBMatrix& packed_1x1,
outT* out,
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc
index 5145869..e091a87 100644
--- a/src/ExecuteKernelU8S8.cc
+++ b/src/ExecuteKernelU8S8.cc
@@ -322,6 +322,12 @@ template class ExecuteKernel<
memCopy<>>;
template class ExecuteKernel<
+ PackAWithIm2Col<uint8_t, int16_t, 3>,
+ PackBMatrix<int8_t, int16_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
PackAWithRowOffset<uint8_t, int32_t>,
PackBMatrix<int8_t, int32_t>,
int32_t,
@@ -334,6 +340,12 @@ template class ExecuteKernel<
memCopy<>>;
template class ExecuteKernel<
+ PackAWithIm2Col<uint8_t, int32_t, 3>,
+ PackBMatrix<int8_t, int32_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
PackAWithQuantRowOffset<uint8_t, int32_t>,
PackBMatrix<int8_t, int32_t>,
int32_t,
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index f3bac97..9195a05 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -236,6 +236,16 @@ template void fbgemmPacked(
int num_threads);
template void fbgemmPacked(
+ PackMatrix<PackAWithIm2Col<uint8_t, int32_t, 3>, uint8_t, int32_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
PackMatrix<PackAWithQuantRowOffset<uint8_t, int32_t>, uint8_t, int32_t>&
packA,
PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
@@ -341,6 +351,16 @@ template void fbgemmPacked(
int num_threads);
template void fbgemmPacked(
+ PackMatrix<PackAWithIm2Col<uint8_t, int16_t, 3>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>& packA,
PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
int32_t* C,
diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc
index e067a3e..8dde696 100644
--- a/src/PackAWithIm2Col.cc
+++ b/src/PackAWithIm2Col.cc
@@ -4,26 +4,37 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
-#include <cpuinfo.h>
+#include <algorithm>
#include <cassert>
#include <iomanip>
#include <iostream>
-#include "fbgemm/Fbgemm.h"
+#include <numeric>
+#include <cpuinfo.h>
-#include <algorithm>
+#include "fbgemm/Fbgemm.h"
namespace fbgemm2 {
-template <typename T, typename accT>
-PackAWithIm2Col<T, accT>::PackAWithIm2Col(
- const conv_param_t& conv_p,
+template <typename T, typename accT, int SPATIAL_DIM>
+PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col(
+ const conv_param_t<SPATIAL_DIM>& conv_p,
const T* sdata,
inpType* pmat,
int32_t zero_pt,
int32_t* row_offset)
- : PackMatrix<PackAWithIm2Col<T, accT>, T, accT>(
- conv_p.MB * conv_p.OH * conv_p.OW,
- conv_p.KH * conv_p.KW * conv_p.IC,
+ : PackMatrix<PackAWithIm2Col<T, accT, SPATIAL_DIM>, T, accT>(
+ conv_p.MB *
+ std::accumulate(
+ conv_p.OUT_DIM.begin(),
+ conv_p.OUT_DIM.end(),
+ 1,
+ std::multiplies<int>()),
+ std::accumulate(
+ conv_p.K.begin(),
+ conv_p.K.end(),
+ 1,
+ std::multiplies<int>()) *
+ conv_p.IC,
pmat,
zero_pt),
conv_p_(conv_p),
@@ -62,8 +73,8 @@ PackAWithIm2Col<T, accT>::PackAWithIm2Col(
}
}
-template <typename T, typename accT>
-void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) {
+template <typename T, typename accT, int SPATIAL_DIM>
+void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) {
block_type_t block_p = {block.row_start,
block.row_size,
block.col_start,
@@ -72,11 +83,87 @@ void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) {
BaseType::packedBlock(block_p);
T* out = BaseType::getBuf();
+ if (SPATIAL_DIM == 3) { // static if
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ int n =
+ i / (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1] * conv_p_.OUT_DIM[2]);
+ int thw =
+ i % (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1] * conv_p_.OUT_DIM[2]);
+ int w = thw % conv_p_.OUT_DIM[2];
+ int h = thw / conv_p_.OUT_DIM[2] % conv_p_.OUT_DIM[1];
+ int t = thw / conv_p_.OUT_DIM[2] / conv_p_.OUT_DIM[1];
+ for (int j = block.col_start;
+ j < block.col_start + block.col_size + conv_p_.IC - 1;
+ j += conv_p_.IC) {
+ int j_blk_id = j / conv_p_.IC;
+ // max( j_blk_id * IC, START) -> min( END, (j_blk_id + 1) * IC )
+ int j_blk_start = std::max(j_blk_id * conv_p_.IC, block.col_start);
+ int j_blk_end = std::min(
+ (j_blk_id + 1) * conv_p_.IC, block.col_start + block.col_size);
+ if (j_blk_start >= j_blk_end) {
+ break;
+ }
+
+ int qrs = j / conv_p_.IC;
+ int s = qrs % conv_p_.K[2];
+ int r = qrs / conv_p_.K[2] % conv_p_.K[1];
+ int q = qrs / conv_p_.K[2] / conv_p_.K[1];
+
+ int t_in = -conv_p_.pad[0] + t * conv_p_.stride[0] + q;
+ int h_in = -conv_p_.pad[1] + h * conv_p_.stride[1] + r;
+ int w_in = -conv_p_.pad[2] + w * conv_p_.stride[2] + s;
+
+ if (t_in < 0 || t_in >= conv_p_.IN_DIM[0] || h_in < 0 ||
+ h_in >= conv_p_.IN_DIM[1] || w_in < 0 ||
+ w_in >= conv_p_.IN_DIM[2]) {
+ // Please note that padding for convolution should be filled with
+ // zero_pt
+ std::memset(
+ &out
+ [(i - block.row_start) * BaseType::blockColSize() +
+ (j_blk_start - block.col_start)],
+ BaseType::zeroPoint(),
+ sizeof(T) * (j_blk_end - j_blk_start));
+ } else {
+ std::memcpy(
+ &out
+ [(i - block.row_start) * BaseType::blockColSize() +
+ j_blk_start - block.col_start],
+ &sdata_
+ [(((n * conv_p_.IN_DIM[0] + t_in) * conv_p_.IN_DIM[1] +
+ h_in) *
+ conv_p_.IN_DIM[2] +
+ w_in) *
+ conv_p_.IC +
+ (j_blk_start % conv_p_.IC)],
+ sizeof(T) * (j_blk_end - j_blk_start));
+ }
+ }
+ // zero fill
+ // Please see the comment in PackAMatrix.cc for zero vs zero_pt fill.
+ if ((block_p.col_start + block_p.col_size) -
+ (block.col_start + block.col_size) >
+ 0) {
+ std::memset(
+ &out
+ [(i - block.row_start) * BaseType::blockColSize() +
+ (block.col_size)],
+ 0,
+ sizeof(T) *
+ ((block_p.col_start + block_p.col_size) -
+ (block.col_start + block.col_size)));
+ }
+ }
+ return;
+ }
+
+ assert(SPATIAL_DIM == 2 && "unsupported conv dimension");
+
for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
- int n = i / (conv_p_.OH * conv_p_.OW);
- int hw = i % (conv_p_.OH * conv_p_.OW);
- int w = hw % conv_p_.OW;
- int h = hw / conv_p_.OW;
+ int n = i / (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1]);
+ int hw = i % (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1]);
+ int w = hw % conv_p_.OUT_DIM[1];
+ int h = hw / conv_p_.OUT_DIM[1];
for (int j = block.col_start;
j < block.col_start + block.col_size + conv_p_.IC - 1;
j += conv_p_.IC) {
@@ -90,13 +177,14 @@ void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) {
}
int rs = j / conv_p_.IC;
- int s = rs % conv_p_.KW;
- int r = rs / conv_p_.KW;
+ int s = rs % conv_p_.K[1];
+ int r = rs / conv_p_.K[1];
- int w_in = -conv_p_.pad_w + w * conv_p_.stride_w + s;
- int h_in = -conv_p_.pad_h + h * conv_p_.stride_h + r;
+ int h_in = -conv_p_.pad[0] + h * conv_p_.stride[0] + r;
+ int w_in = -conv_p_.pad[1] + w * conv_p_.stride[1] + s;
- if (h_in < 0 || h_in >= conv_p_.IH || w_in < 0 || w_in >= conv_p_.IW) {
+ if (h_in < 0 || h_in >= conv_p_.IN_DIM[0] || w_in < 0 ||
+ w_in >= conv_p_.IN_DIM[1]) {
// Please note that padding for convolution should be filled with
// zero_pt
std::memset(
@@ -111,7 +199,8 @@ void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) {
[(i - block.row_start) * BaseType::blockColSize() +
j_blk_start - block.col_start],
&sdata_
- [((n * conv_p_.IH + h_in) * conv_p_.IW + w_in) * conv_p_.IC +
+ [((n * conv_p_.IN_DIM[0] + h_in) * conv_p_.IN_DIM[1] + w_in) *
+ conv_p_.IC +
(j_blk_start % conv_p_.IC)],
sizeof(T) * (j_blk_end - j_blk_start));
}
@@ -133,8 +222,9 @@ void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) {
}
}
-template <typename T, typename accT>
-void PackAWithIm2Col<T, accT>::printPackedMatrix(std::string name) {
+template <typename T, typename accT, int SPATIAL_DIM>
+void PackAWithIm2Col<T, accT, SPATIAL_DIM>::printPackedMatrix(
+ std::string name) {
std::cout << name << ":"
<< "[" << BaseType::numPackedRows() << ", "
<< BaseType::numPackedCols() << "]" << std::endl;
@@ -155,8 +245,8 @@ void PackAWithIm2Col<T, accT>::printPackedMatrix(std::string name) {
std::cout << std::endl;
}
-template <typename T, typename accT>
-int PackAWithIm2Col<T, accT>::rowOffsetBufferSize() {
+template <typename T, typename accT, int SPATIAL_DIM>
+int PackAWithIm2Col<T, accT, SPATIAL_DIM>::rowOffsetBufferSize() {
if (cpuinfo_initialize()) {
if (cpuinfo_has_x86_avx512f()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
@@ -174,4 +264,7 @@ int PackAWithIm2Col<T, accT>::rowOffsetBufferSize() {
template class PackAWithIm2Col<uint8_t, int32_t>;
template class PackAWithIm2Col<uint8_t, int16_t>;
+template class PackAWithIm2Col<uint8_t, int32_t, 3>;
+template class PackAWithIm2Col<uint8_t, int16_t, 3>;
+
} // namespace fbgemm2
diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc
index 85000ac..37b4e88 100644
--- a/src/PackMatrix.cc
+++ b/src/PackMatrix.cc
@@ -64,6 +64,10 @@ template class PackMatrix<
int32_t>;
template class PackMatrix<PackAWithIm2Col<uint8_t, int32_t>, uint8_t, int32_t>;
+template class PackMatrix<
+ PackAWithIm2Col<uint8_t, int32_t, 3>,
+ uint8_t,
+ int32_t>;
template class PackMatrix<
PackAWithQuantRowOffset<uint8_t, int32_t>,
@@ -74,6 +78,10 @@ template class PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>;
// int16 accumulation
template class PackMatrix<PackAWithIm2Col<uint8_t, int16_t>, uint8_t, int16_t>;
+template class PackMatrix<
+ PackAWithIm2Col<uint8_t, int16_t, 3>,
+ uint8_t,
+ int16_t>;
template class PackMatrix<
PackAWithRowOffset<uint8_t, int16_t>,
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
index 6bf2d65..4b919c1 100644
--- a/src/RefImplementations.cc
+++ b/src/RefImplementations.cc
@@ -219,10 +219,10 @@ void spmdm_ref(
}
int32_t clip_16bit(int32_t x) {
- if (x > std::numeric_limits<int16_t>::max()) {
- return std::min<int>(std::numeric_limits<int16_t>::max(), x);
- } else if (x < std::numeric_limits<int16_t>::min()) {
- return std::max<int>(std::numeric_limits<int16_t>::min(), x);
+ if (x > numeric_limits<int16_t>::max()) {
+ return std::min<int>(numeric_limits<int16_t>::max(), x);
+ } else if (x < numeric_limits<int16_t>::min()) {
+ return std::max<int>(numeric_limits<int16_t>::min(), x);
} else {
return x;
}
@@ -235,36 +235,38 @@ int32_t clip_16bit(int32_t x) {
* Ao: NHWC: NH_1W_1 x RSC_0
*/
void im2col_ref(
- const conv_param_t& conv_p,
- const std::uint8_t* A,
- std::int32_t A_zero_point,
- std::uint8_t* Ao) {
+ const conv_param_t<>& conv_p,
+ const uint8_t* A,
+ int32_t A_zero_point,
+ uint8_t* Ao) {
for (int n = 0; n < conv_p.MB; ++n) {
- for (int h = 0; h < conv_p.OH; ++h) {
- for (int w = 0; w < conv_p.OW; ++w) {
- for (int r = 0; r < conv_p.KH; ++r) {
- int h_in = -conv_p.pad_h + h * conv_p.stride_h + r;
- for (int s = 0; s < conv_p.KW; ++s) {
- int w_in = -conv_p.pad_w + w * conv_p.stride_w + s;
- if (h_in < 0 || h_in >= conv_p.IH || w_in < 0 ||
- w_in >= conv_p.IW) {
- std::memset(
- &Ao[((((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.KH + r) *
- conv_p.KW +
+ for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) {
+ for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) {
+ for (int r = 0; r < conv_p.K[0]; ++r) {
+ int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r;
+ for (int s = 0; s < conv_p.K[1]; ++s) {
+ int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s;
+ if (h_in < 0 || h_in >= conv_p.IN_DIM[0] || w_in < 0 ||
+ w_in >= conv_p.IN_DIM[1]) {
+ memset(
+ &Ao[((((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) *
+ conv_p.K[0] +
+ r) *
+ conv_p.K[1] +
s) *
- conv_p.IC +
- 0],
+ conv_p.IC],
A_zero_point,
sizeof(uint8_t) * conv_p.IC);
} else {
- std::memcpy(
- &Ao[((((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.KH + r) *
- conv_p.KW +
+ memcpy(
+ &Ao[((((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) *
+ conv_p.K[0] +
+ r) *
+ conv_p.K[1] +
s) *
- conv_p.IC +
- 0],
- &A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) * conv_p.IC +
- 0],
+ conv_p.IC],
+ &A[((n * conv_p.IN_DIM[0] + h_in) * conv_p.IN_DIM[1] + w_in) *
+ conv_p.IC],
sizeof(uint8_t) * conv_p.IC);
}
} // for each s
@@ -274,44 +276,168 @@ void im2col_ref(
} // for each n
}
+/* Imitate the Im2Col<float, CPUContext, StorageOrder::NHWC> function
+ * from caffe2/utils/math_cpu.cc
+ * NHWC StorageOrder/Layout
+ * A: NHWC: NT_0H_0W_0 x C_0
+ * Ao: NHWC: NT_1H_1W_1 x QRSC_0
+ */
+void im2col3d_ref(
+ const conv_param_t<3>& conv_p,
+ const uint8_t* A,
+ int32_t A_zero_point,
+ uint8_t* Ao) {
+ for (int n = 0; n < conv_p.MB; ++n) {
+ for (int t = 0; t < conv_p.OUT_DIM[0]; ++t) {
+ for (int h = 0; h < conv_p.OUT_DIM[1]; ++h) {
+ for (int w = 0; w < conv_p.OUT_DIM[2]; ++w) {
+ for (int q = 0; q < conv_p.K[0]; ++q) {
+ int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q;
+ for (int r = 0; r < conv_p.K[1]; ++r) {
+ int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r;
+ for (int s = 0; s < conv_p.K[2]; ++s) {
+ int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s;
+ if (t_in < 0 || t_in >= conv_p.IN_DIM[0] || h_in < 0 ||
+ h_in >= conv_p.IN_DIM[1] || w_in < 0 ||
+ w_in >= conv_p.IN_DIM[2]) {
+ memset(
+ &Ao[((((((n * conv_p.OUT_DIM[0] + t) * conv_p.OUT_DIM[1] +
+ h) *
+ conv_p.OUT_DIM[2] +
+ w) *
+ conv_p.K[0] +
+ q) *
+ conv_p.K[1] +
+ r) *
+ conv_p.K[2] +
+ s) *
+ conv_p.IC],
+ A_zero_point,
+ sizeof(uint8_t) * conv_p.IC);
+ } else {
+ memcpy(
+ &Ao[((((((n * conv_p.OUT_DIM[0] + t) * conv_p.OUT_DIM[1] +
+ h) *
+ conv_p.OUT_DIM[2] +
+ w) *
+ conv_p.K[0] +
+ q) *
+ conv_p.K[1] +
+ r) *
+ conv_p.K[2] +
+ s) *
+ conv_p.IC],
+ &A[(((n * conv_p.IN_DIM[0] + t_in) * conv_p.IN_DIM[1] +
+ h_in) *
+ conv_p.IN_DIM[2] +
+ w_in) *
+ conv_p.IC],
+ sizeof(uint8_t) * conv_p.IC);
+ }
+ } // for each s
+ } // for each r
+ } // for each q
+ } // for each w
+ } // for each h
+ } // for each t
+ } // for each n
+}
+
void conv_ref(
- const conv_param_t& conv_p,
- const std::uint8_t* A,
- std::int32_t A_zero_point,
- const std::int8_t* B,
- std::int32_t* C) {
+ const conv_param_t<>& conv_p,
+ const uint8_t* A,
+ int32_t A_zero_point,
+ const int8_t* B,
+ int32_t* C) {
// filters are assumed to be in RSCK format
assert(conv_p.G == 1 && "Groups != 1 not supported yet");
for (int n = 0; n < conv_p.MB; ++n) {
- for (int h = 0; h < conv_p.OH; ++h) {
- for (int w = 0; w < conv_p.OW; ++w) {
+ for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) {
+ for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) {
for (int k = 0; k < conv_p.OC; ++k) {
int sum = 0;
- for (int r = 0; r < conv_p.KH; ++r) {
- int h_in = -conv_p.pad_h + h * conv_p.stride_h + r;
- for (int s = 0; s < conv_p.KW; ++s) {
- int w_in = -conv_p.pad_w + w * conv_p.stride_w + s;
+ for (int r = 0; r < conv_p.K[0]; ++r) {
+ int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r;
+ for (int s = 0; s < conv_p.K[1]; ++s) {
+ int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s;
for (int c = 0; c < conv_p.IC; ++c) {
- int a = h_in < 0 || h_in >= conv_p.IH || w_in < 0 ||
- w_in >= conv_p.IW
+ int a = h_in < 0 || h_in >= conv_p.IN_DIM[0] || w_in < 0 ||
+ w_in >= conv_p.IN_DIM[1]
? A_zero_point
- : A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) *
+ : A[((n * conv_p.IN_DIM[0] + h_in) * conv_p.IN_DIM[1] +
+ w_in) *
conv_p.IC +
c];
int b =
- B[((r * conv_p.KW + s) * conv_p.IC + c) * conv_p.OC + k];
+ B[((r * conv_p.K[1] + s) * conv_p.IC + c) * conv_p.OC + k];
sum += a * b;
} // for each c
} // for each s
} // for each r
- C[((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k] = sum;
+ C[((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) * conv_p.OC +
+ k] = sum;
} // for each k
} // for each w
} // for each h
} // for each n
}
+void conv3d_ref(
+ const conv_param_t<3>& conv_p,
+ const uint8_t* A,
+ int32_t A_zero_point,
+ const int8_t* B,
+ int32_t* C) {
+ // filters are assumed to be in RSCK format
+ assert(conv_p.G == 1 && "Groups != 1 not supported yet");
+
+ for (int n = 0; n < conv_p.MB; ++n) {
+ for (int t = 0; t < conv_p.OUT_DIM[0]; ++t) {
+ for (int h = 0; h < conv_p.OUT_DIM[1]; ++h) {
+ for (int w = 0; w < conv_p.OUT_DIM[2]; ++w) {
+ for (int k = 0; k < conv_p.OC; ++k) {
+ int sum = 0;
+ for (int q = 0; q < conv_p.K[0]; ++q) {
+ int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q;
+ for (int r = 0; r < conv_p.K[1]; ++r) {
+ int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r;
+ for (int s = 0; s < conv_p.K[2]; ++s) {
+ int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s;
+ for (int c = 0; c < conv_p.IC; ++c) {
+ int a = t_in < 0 || t_in >= conv_p.IN_DIM[0] || h_in < 0 ||
+ h_in >= conv_p.IN_DIM[1] || w_in < 0 ||
+ w_in >= conv_p.IN_DIM[2]
+ ? A_zero_point
+ : A[(((n * conv_p.IN_DIM[0] + t_in) * conv_p.IN_DIM[1] +
+ h_in) *
+ conv_p.IN_DIM[2] +
+ w_in) *
+ conv_p.IC +
+ c];
+ int b =
+ B[(((q * conv_p.K[1] + r) * conv_p.K[2] + s) *
+ conv_p.IC +
+ c) *
+ conv_p.OC +
+ k];
+ sum += a * b;
+ } // for each c
+ } // for each s
+ } // for each r
+ } // for each q
+ C[(((n * conv_p.OUT_DIM[0] + t) * conv_p.OUT_DIM[1] + h) *
+ conv_p.OUT_DIM[2] +
+ w) *
+ conv_p.OC +
+ k] = sum;
+ } // for each k
+ } // for each w
+ } // for each h
+ } // for each t
+ } // for each n
+}
+
void depthwise_3x3_pad_1_ref(
int N,
int H,
diff --git a/src/RefImplementations.h b/src/RefImplementations.h
index e9eaeed..69d060a 100644
--- a/src/RefImplementations.h
+++ b/src/RefImplementations.h
@@ -147,7 +147,14 @@ int32_t clip_16bit(int32_t x);
* The output C is assumed to be in NHoWoC format.
*/
void conv_ref(
- const conv_param_t& conv_p,
+ const conv_param_t<>& conv_p,
+ const std::uint8_t* A,
+ std::int32_t A_zero_point,
+ const std::int8_t* B,
+ std::int32_t* C);
+
+void conv3d_ref(
+ const conv_param_t<3>& conv_p,
const std::uint8_t* A,
std::int32_t A_zero_point,
const std::int8_t* B,
@@ -159,7 +166,18 @@ void conv_ref(
* The output A is assumed to be in NHoWoRSC format.
*/
void im2col_ref(
- const conv_param_t& conv_p,
+ const conv_param_t<>& conv_p,
+ const std::uint8_t* A,
+ std::int32_t A_zero_point,
+ std::uint8_t* Ao);
+
+/*
+ * @brief Reference implementation of im2col 3D operation.
+ * The input A is assumed to be in NTiHiWiC format.
+ * The output A is assumed to be in NToHoWoK0K1K2C format.
+ */
+void im2col3d_ref(
+ const conv_param_t<3>& conv_p,
const std::uint8_t* A,
std::int32_t A_zero_point,
std::uint8_t* Ao);
diff --git a/test/Im2ColFusedRequantizeTest.cc b/test/Im2ColFusedRequantizeTest.cc
index c09c770..3ac8d28 100644
--- a/test/Im2ColFusedRequantizeTest.cc
+++ b/test/Im2ColFusedRequantizeTest.cc
@@ -19,64 +19,64 @@ using namespace std;
namespace fbgemm2 {
-// From Xray OCR
-static vector<conv_param_t> shapes = {
+// From Faster-RCNN with ShuffleNet
+static vector<conv_param_t<>> shapes = {
// MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w
- conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0),
- conv_param_t(1, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 0, 0),
- conv_param_t(2, 32, 32, 14, 14, 1, 3, 3, 1, 1, 1, 1),
- conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 1, 1, 1, 1),
- // conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 1, 1, 1, 1 ),
- // conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 1, 1, 1, 1 ),
- // conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 1, 1, 1, 1 ),
- // conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 1, 1, 1, 1 ),
- // conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 1, 1, 1, 1 ),
- // conv_param_t( 1, 272, 272, 75, 100, 1, 3, 3, 1, 1, 1, 1 ),
- // conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 1, 1, 1, 1 ),
- // conv_param_t(1, 272, 272, 109, 75, 1, 3, 3, 1, 1, 1, 1),
- // conv_param_t(1, 544, 544, 24, 63, 1, 3, 3, 1, 1, 1, 1),
- // conv_param_t( 1, 544, 544, 33, 63, 1, 3, 3, 1, 1, 1, 1 ),
- // conv_param_t( 1, 544, 544, 34, 50, 1, 3, 3, 1, 1, 1, 1 ),
- // conv_param_t( 1, 544, 544, 36, 63, 1, 3, 3, 1, 1, 1, 1 ),
- // conv_param_t( 1, 544, 544, 38, 38, 1, 3, 3, 1, 1, 1, 1 ),
- // conv_param_t( 1, 544, 544, 38, 40, 1, 3, 3, 1, 1, 1, 1 ),
- // conv_param_t( 1, 544, 544, 47, 38, 1, 3, 3, 1, 1, 1, 1 ),
- // conv_param_t( 1, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ),
- // conv_param_t(51, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1),
- // conv_param_t( 100, 1088, 1088, 7, 7, 1, 3, 3, 1, 1, 1, 1 ),
- conv_param_t(1, 248, 248, 93, 250, 1, 3, 3, 2, 2, 1, 1),
- // conv_param_t( 1, 248, 248, 128, 250, 1, 3, 3, 2, 2, 1, 1 ),
- // conv_param_t( 1, 248, 248, 133, 200, 1, 3, 3, 2, 2, 1, 1 ),
- // conv_param_t( 1, 248, 248, 150, 150, 1, 3, 3, 2, 2, 1, 1 ),
- // conv_param_t( 1, 248, 248, 150, 151, 1, 3, 3, 2, 2, 1, 1 ),
- // conv_param_t( 1, 248, 248, 150, 158, 1, 3, 3, 2, 2, 1, 1 ),
- // conv_param_t( 1, 248, 248, 188, 150, 1, 3, 3, 2, 2, 1, 1 ),
- // conv_param_t(1, 248, 248, 225, 150, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(1, 272, 272, 47, 125, 1, 3, 3, 2, 2, 1, 1),
- // conv_param_t( 1, 272, 272, 64, 125, 1, 3, 3, 2, 2, 1, 1 ),
- // conv_param_t( 1, 272, 272, 66, 125, 1, 3, 3, 2, 2, 1, 1 ),
- // conv_param_t( 1, 272, 272, 67, 100, 1, 3, 3, 2, 2, 1, 1 ),
- // conv_param_t( 1, 272, 272, 75, 75, 1, 3, 3, 2, 2, 1, 1 ),
- // conv_param_t( 1, 272, 272, 75, 76, 1, 3, 3, 2, 2, 1, 1 ),
- // conv_param_t( 1, 272, 272, 94, 75, 1, 3, 3, 2, 2, 1, 1 ),
- // conv_param_t( 1, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ),
- // conv_param_t(51, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1),
- conv_param_t(3, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1),
- // conv_param_t( 100, 544, 544, 14, 14, 1, 3, 3, 2, 2, 1, 1 ),
- conv_param_t(1, 8, 8, 4, 4, 1, 3, 3, 1, 1, 1, 1),
+ conv_param_t<>(1, 32, 32, {14, 14}, 1, {3, 3}, {1, 1}, {0, 0}),
+ conv_param_t<>(1, 32, 32, {14, 14}, 1, {3, 3}, {1, 1}, {1, 1}),
+ conv_param_t<>(2, 32, 32, {14, 14}, 1, {3, 3}, {1, 1}, {0, 0}),
+ conv_param_t<>(2, 32, 32, {14, 14}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 272, 272, {47, 125}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 272, 272, {64, 125}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 272, 272, {66, 125}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 272, 272, {67, 100}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 272, 272, {75, 75}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 272, 272, {75, 76}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 272, 272, {75, 100}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 272, 272, {94, 75}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 272, 272, {109, 75}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 544, 544, {24, 63}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 544, 544, {33, 63}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 544, 544, {34, 50}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 544, 544, {36, 63}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 544, 544, {38, 38}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 544, 544, {38, 40}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 544, 544, {47, 38}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 1088, 1088, {7, 7}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(51, 1088, 1088, {7, 7}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(100, 1088, 1088, {7, 7}, 1, {3, 3}, {1, 1}, {1, 1}),
+ // conv_param_t<>(1, 248, 248, {93, 250}, 1, {3, 3}, {2, 2}, {1, 1}),
+ // conv_param_t<>(1, 248, 248, {128, 250}, 1, {3, 3}, {2, 2}, {1, 1}),
+ // conv_param_t<>(1, 248, 248, {133, 200}, 1, {3, 3}, {2, 2}, {1, 1}),
+ // conv_param_t<>(1, 248, 248, {150, 150}, 1, {3, 3}, {2, 2}, {1, 1}),
+ // conv_param_t<>(1, 248, 248, {150, 151}, 1, {3, 3}, {2, 2}, {1, 1}),
+ // conv_param_t<>(1, 248, 248, {150, 158}, 1, {3, 3}, {2, 2}, {1, 1}),
+ // conv_param_t<>(1, 248, 248, {188, 150}, 1, {3, 3}, {2, 2}, {1, 1}),
+ // conv_param_t<>(1, 248, 248, {225, 150}, 1, {3, 3}, {2, 2}, {1, 1}),
+ // conv_param_t<>(1, 272, 272, {47, 125}, 1, {3, 3}, {2, 2}, {1, 1}),
+ // conv_param_t<>(1, 272, 272, {64, 125}, 1, {3, 3}, {2, 2}, {1, 1}),
+ // conv_param_t<>(1, 272, 272, {66, 125}, 1, {3, 3}, {2, 2}, {1, 1}),
+ // conv_param_t<>(1, 272, 272, {67, 100}, 1, {3, 3}, {2, 2}, {1, 1}),
+ // conv_param_t<>(1, 272, 272, {75, 75}, 1, {3, 3}, {2, 2}, {1, 1}),
+ // conv_param_t<>(1, 272, 272, {75, 76}, 1, {3, 3}, {2, 2}, {1, 1}),
+ // conv_param_t<>(1, 272, 272, {94, 75}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 544, 544, {14, 14}, 1, {3, 3}, {2, 2}, {1, 1}),
+ // conv_param_t<>(51, 544, 544, {14, 14}, 1, {3, 3}, {2, 2}, {1, 1}),
+ // conv_param_t<>(3, 544, 544, {14, 14}, 1, {3, 3}, {2, 2}, {1, 1}),
+ // conv_param_t<>(100, 544, 544, {14, 14}, 1, {3, 3}, {2, 2}, {1, 1}),
+ conv_param_t<>(1, 8, 8, {4, 4}, 1, {3, 3}, {1, 1}, {1, 1}),
};
TEST(FBGemmIm2colTest, Acc32Test) {
for (auto conv_p : shapes) {
aligned_vector<uint8_t> Aint8(
- conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0);
+ conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 0);
aligned_vector<int8_t> Bint8(
- conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0);
+ conv_p.K[0] * conv_p.K[1] * conv_p.IC * conv_p.OC, 0);
aligned_vector<int32_t> Cint32_ref(
- conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0.0f);
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0.0f);
aligned_vector<int32_t> Cint32_fb(
- conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0);
randFill(Aint8, 0, 80);
int32_t Aint8_zero_point = 43;
@@ -90,7 +90,7 @@ TEST(FBGemmIm2colTest, Acc32Test) {
Cint32_ref.data());
int NDim = conv_p.OC;
- int KDim = conv_p.KH * conv_p.KW * conv_p.IC;
+ int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC;
vector<int32_t> row_offset_buf;
row_offset_buf.resize(
@@ -118,13 +118,17 @@ TEST(FBGemmIm2colTest, Acc32Test) {
// correctness check
for (int n = 0; n < conv_p.MB; ++n) {
- for (int h = 0; h < conv_p.OH; ++h) {
- for (int w = 0; w < conv_p.OW; ++w) {
+ for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) {
+ for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) {
for (int k = 0; k < conv_p.OC; ++k) {
int32_t expected = Cint32_ref
- [((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k];
+ [((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) *
+ conv_p.OC +
+ k];
int32_t actual = Cint32_fb
- [((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k];
+ [((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) *
+ conv_p.OC +
+ k];
EXPECT_EQ(expected, actual)
<< "Im2Col fused results differ at (" << n << ", " << h << ", "
<< w << ", " << k << ").";
@@ -140,13 +144,13 @@ TEST(FBGemmIm2colTest, Acc32Test) {
TEST(FBGemmIm2colTest, Acc16Test) {
for (auto conv_p : shapes) {
aligned_vector<uint8_t> Aint8(
- conv_p.MB * conv_p.IH * conv_p.IW * conv_p.IC, 0);
+ conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 0);
aligned_vector<int8_t> Bint8(
- conv_p.KH * conv_p.KW * conv_p.IC * conv_p.OC, 0);
+ conv_p.K[0] * conv_p.K[1] * conv_p.IC * conv_p.OC, 0);
aligned_vector<int32_t> Cint32_ref(
- conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0.0f);
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0.0f);
aligned_vector<int32_t> Cint32_fb(
- conv_p.MB * conv_p.OH * conv_p.OW * conv_p.OC, 0);
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0);
randFill(Aint8, 0, 5);
int32_t Aint8_zero_point = 4;
@@ -160,7 +164,7 @@ TEST(FBGemmIm2colTest, Acc16Test) {
Cint32_ref.data());
int NDim = conv_p.OC;
- int KDim = conv_p.KH * conv_p.KW * conv_p.IC;
+ int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC;
vector<int32_t> row_offset_buf;
row_offset_buf.resize(
@@ -188,13 +192,17 @@ TEST(FBGemmIm2colTest, Acc16Test) {
// correctness check
for (int n = 0; n < conv_p.MB; ++n) {
- for (int h = 0; h < conv_p.OH; ++h) {
- for (int w = 0; w < conv_p.OW; ++w) {
+ for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) {
+ for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) {
for (int k = 0; k < conv_p.OC; ++k) {
int32_t expected = Cint32_ref
- [((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k];
+ [((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) *
+ conv_p.OC +
+ k];
int32_t actual = Cint32_fb
- [((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k];
+ [((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) *
+ conv_p.OC +
+ k];
EXPECT_EQ(expected, actual)
<< "Im2Col fused results differ at (" << n << ", " << h << ", "
<< w << ", " << k << ").";
@@ -206,5 +214,217 @@ TEST(FBGemmIm2colTest, Acc16Test) {
} // for each shape
} // Acc16Test
+static vector<conv_param_t<3>> shapes_3d = {
+ // MB, IC, OC, IT, IH, IW, G, KT, KH, KW, stride_t, stride_h, stride_w,
+ // pad_t, pad_h, pad_w
+ // conv_param_t<
+ // 3>(1, 3, 64, {32, 112, 112}, 1, {3, 7, 7}, {1, 2, 2}, {1, 3, 3}),
+ // conv_param_t<
+ // 3>(1, 64, 64, {32, 56, 56}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}),
+ // conv_param_t<
+ // 3>(1, 64, 256, {32, 56, 56}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}),
+ // conv_param_t<
+ // 3>(1, 256, 64, {32, 56, 56}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}),
+ // conv_param_t<
+ // 3>(1, 256, 128, {32, 56, 56}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}),
+ // conv_param_t<
+ // 3>(1, 256, 512, {32, 56, 56}, 1, {1, 1, 1}, {2, 2, 2}, {0, 0, 0}),
+ // conv_param_t<
+ // 3>(1, 128, 512, {16, 28, 28}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}),
+ // conv_param_t<
+ // 3>(1, 512, 128, {16, 28, 28}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}),
+ // conv_param_t<
+ // 3>(1, 512, 256, {16, 28, 28}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}),
+ // conv_param_t<
+ // 3>(1, 512, 1024, {16, 28, 28}, 1, {1, 1, 1}, {2, 2, 2}, {0, 0, 0}),
+ // conv_param_t<
+ // 3>(1, 256, 1024, {8, 14, 14}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}),
+ // conv_param_t<
+ // 3>(1, 1024, 256, {8, 14, 14}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}),
+ // conv_param_t<
+ // 3>(1, 1024, 512, {8, 14, 14}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}),
+ // conv_param_t<
+ // 3>(1, 1024, 2048, {8, 14, 14}, 1, {1, 1, 1}, {2, 2, 2}, {0, 0, 0}),
+ // conv_param_t<
+ // 3>(1, 2048, 512, {8, 14, 14}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}),
+ // conv_param_t<
+ // 3>(1, 512, 2048, {4, 7, 7}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}),
+
+ conv_param_t<
+ 3>(1, 3, 4, {32, 112, 112}, 1, {3, 7, 7}, {1, 2, 2}, {1, 3, 3}),
+ conv_param_t<
+ 3>(1, 8, 16, {4, 7, 7}, 1, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}),
+ conv_param_t<
+ 3>(1, 8, 16, {8, 14, 14}, 1, {1, 1, 1}, {2, 2, 2}, {0, 0, 0}),
+};
+
+TEST(FBGemmIm2colTest, 3DAcc32Test) {
+ for (auto conv_p : shapes_3d) {
+ aligned_vector<uint8_t> Aint8(
+ conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IN_DIM[2] *
+ conv_p.IC,
+ 0);
+ aligned_vector<int8_t> Bint8(
+ conv_p.K[0] * conv_p.K[1] * conv_p.K[2] * conv_p.IC * conv_p.OC, 0);
+ aligned_vector<int32_t> Cint32_ref(
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OUT_DIM[2] *
+ conv_p.OC,
+ 0.0f);
+ aligned_vector<int32_t> Cint32_fb(
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OUT_DIM[2] *
+ conv_p.OC,
+ 0);
+
+ randFill(Aint8, 0, 80);
+ int32_t Aint8_zero_point = 43;
+ randFill(Bint8, -16, 16);
+
+ conv3d_ref(
+ conv_p,
+ Aint8.data(),
+ Aint8_zero_point,
+ Bint8.data(),
+ Cint32_ref.data());
+
+ int NDim = conv_p.OC;
+ int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.K[2] * conv_p.IC;
+
+ vector<int32_t> row_offset_buf;
+ row_offset_buf.resize(
+ PackAWithIm2Col<uint8_t, int32_t, 3>::rowOffsetBufferSize());
+
+ PackAWithIm2Col<uint8_t, int32_t, 3> packA(
+ conv_p, Aint8.data(), nullptr, Aint8_zero_point, row_offset_buf.data());
+
+ PackBMatrix<int8_t, int32_t> packedB(
+ matrix_op_t::NoTranspose, KDim, NDim, Bint8.data(), NDim);
+
+ // no-op output process objects
+ DoNothing<int32_t, int32_t> doNothing32BitObj;
+ memCopy<> memcopyObj(doNothing32BitObj);
+
+ fbgemmPacked(
+ packA,
+ packedB,
+ Cint32_fb.data(),
+ Cint32_fb.data(),
+ NDim,
+ memcopyObj,
+ 0,
+ 1);
+
+ // correctness check
+ for (int n = 0; n < conv_p.MB; ++n) {
+ for (int t = 0; t < conv_p.OUT_DIM[0]; ++t) {
+ for (int h = 0; h < conv_p.OUT_DIM[1]; ++h) {
+ for (int w = 0; w < conv_p.OUT_DIM[2]; ++w) {
+ for (int k = 0; k < conv_p.OC; ++k) {
+ int32_t expected = Cint32_ref
+ [(((n * conv_p.OUT_DIM[0] + t) * conv_p.OUT_DIM[1] + h) *
+ conv_p.OUT_DIM[2] +
+ w) *
+ conv_p.OC +
+ k];
+ int32_t actual = Cint32_fb
+ [(((n * conv_p.OUT_DIM[0] + t) * conv_p.OUT_DIM[1] + h) *
+ conv_p.OUT_DIM[2] +
+ w) *
+ conv_p.OC +
+ k];
+ EXPECT_EQ(expected, actual)
+ << "Im2Col fused results differ at (" << n << ", " << t
+ << ", " << h << ", " << w << ", " << k << ").";
+ }
+ }
+ }
+ }
+ }
+ } // for each shape
+} // Acc32Test
+
+
+TEST(FBGemmIm2colTest, 3DAcc16Test) {
+ for (auto conv_p : shapes_3d) {
+ aligned_vector<uint8_t> Aint8(
+ conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IN_DIM[2] *
+ conv_p.IC,
+ 0);
+ aligned_vector<int8_t> Bint8(
+ conv_p.K[0] * conv_p.K[1] * conv_p.K[2] * conv_p.IC * conv_p.OC, 0);
+ aligned_vector<int32_t> Cint32_ref(
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OUT_DIM[2] *
+ conv_p.OC,
+ 0.0f);
+ aligned_vector<int32_t> Cint32_fb(
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OUT_DIM[2] *
+ conv_p.OC,
+ 0);
+
+ randFill(Aint8, 0, 5);
+ int32_t Aint8_zero_point = 4;
+ randFill(Bint8, -4, 4);
+
+ conv3d_ref(
+ conv_p,
+ Aint8.data(),
+ Aint8_zero_point,
+ Bint8.data(),
+ Cint32_ref.data());
+
+ int NDim = conv_p.OC;
+ int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.K[2] * conv_p.IC;
+
+ vector<int32_t> row_offset_buf;
+ row_offset_buf.resize(
+ PackAWithIm2Col<uint8_t, int16_t, 3>::rowOffsetBufferSize());
+
+ PackAWithIm2Col<uint8_t, int16_t, 3> packA(
+ conv_p, Aint8.data(), nullptr, Aint8_zero_point, row_offset_buf.data());
+
+ PackBMatrix<int8_t, int16_t> packedB(
+ matrix_op_t::NoTranspose, KDim, NDim, Bint8.data(), NDim);
+
+ // no-op output process objects
+ DoNothing<int32_t, int32_t> doNothing32BitObj;
+ memCopy<> memcopyObj(doNothing32BitObj);
+
+ fbgemmPacked(
+ packA,
+ packedB,
+ Cint32_fb.data(),
+ Cint32_fb.data(),
+ NDim,
+ memcopyObj,
+ 0,
+ 1);
+
+ // correctness check
+ for (int n = 0; n < conv_p.MB; ++n) {
+ for (int t = 0; t < conv_p.OUT_DIM[0]; ++t) {
+ for (int h = 0; h < conv_p.OUT_DIM[1]; ++h) {
+ for (int w = 0; w < conv_p.OUT_DIM[2]; ++w) {
+ for (int k = 0; k < conv_p.OC; ++k) {
+ int32_t expected = Cint32_ref
+ [(((n * conv_p.OUT_DIM[0] + t) * conv_p.OUT_DIM[1] + h) *
+ conv_p.OUT_DIM[2] +
+ w) *
+ conv_p.OC +
+ k];
+ int32_t actual = Cint32_fb
+ [(((n * conv_p.OUT_DIM[0] + t) * conv_p.OUT_DIM[1] + h) *
+ conv_p.OUT_DIM[2] +
+ w) *
+ conv_p.OC +
+ k];
+ EXPECT_EQ(expected, actual)
+ << "Im2Col fused results differ at (" << n << ", " << t
+ << ", " << h << ", " << w << ", " << k << ").";
+ }
+ }
+ }
+ }
+ }
+ } // for each shape
+} // Acc16Test
} // namespace fbgemm2