
github.com/marian-nmt/FBGEMM.git
author     Jongsoo Park <jongsoo@fb.com>  2018-11-29 06:41:59 +0300
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2018-11-29 06:44:08 +0300
commit     027de07a11a0460fd1daffb026d50dba0e56eb79 (patch)
tree       5e0f497059d6a22b18de8508c2031ffdbc9f52d3
parent     90535d3da35f9d3da6a8dbd62da0c68d01696924 (diff)
sparse convolution output processing (#27)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/27

DoSpmdmOnInpBuffer can't be used together with PackAWithIm2Col because DoSpmdmOnInpBuffer expects an im2col'ed A matrix. This diff implements DoSConvOnInpBuffer, which does sparse convolution directly on the A input without im2col. The performance is not yet well optimized, and we need to see whether this implementation is good enough to get good resnet50 performance.

Reviewed By: dskhudia

Differential Revision: D13192336

fbshipit-source-id: 2076555ba9749e111afbaec408a2bfa0f55bd5bc
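For orientation, here is a condensed sketch of how DoSConvOnInpBuffer is wired into the output-processing pipeline, based on the SConvTest added to test/Im2ColFusedRequantizeTest.cc further down in this diff. Setup of conv_p, Aint8, packA, packedB, B_csc, col_offsets, NDim, and the quantization parameters is omitted; the names follow the test, and TENSOR quantization granularity with a single thread is assumed here.

// Sketch only; mirrors the usage in the SConvTest added below.
DoNothing<> doNothingObj{};
ReQuantizeOutput<false, QuantizationGranularity::TENSOR> reqObj(
    doNothingObj,
    C_multiplier.data(),
    C_zero_pt,
    Aint8_zero_point,
    Bint8_zero_point.data(),
    packA.getRowOffsetBuffer(),
    col_offsets.data(),
    nullptr,
    conv_p.G * NDim,
    conv_p.G);

// DoSConvOnInpBuffer applies the sparse filter coefficients in B_csc directly
// to the un-im2col'ed activations Aint8, accumulating into the int32 buffer,
// and then forwards that buffer to requantization.
DoSConvOnInpBuffer<uint8_t, int32_t, decltype(reqObj)> sconvObj(
    reqObj, Aint8.data(), conv_p, Aint8_zero_point, B_csc);

// The dense remainder of the filter goes through the regular GEMM path.
fbgemmPacked(
    packA,
    packedB,
    Cint8_fb.data(),
    Cint32_fb.data(),
    conv_p.G * NDim,
    sconvObj,
    /*thread_id=*/0,
    /*num_threads=*/1);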
-rw-r--r--   include/fbgemm/Fbgemm.h                 47
-rw-r--r--   include/fbgemm/FbgemmI8Spmdm.h          35
-rw-r--r--   include/fbgemm/OutputProcessing-inl.h   12
-rw-r--r--   src/ExecuteKernelU8S8.cc                18
-rw-r--r--   src/Fbgemm.cc                           25
-rw-r--r--   src/FbgemmI8Spmdm.cc                    72
-rw-r--r--   test/Im2ColFusedRequantizeTest.cc      232
7 files changed, 429 insertions, 12 deletions
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index 17a07e5..0ef6a82 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -808,7 +808,7 @@ class ReluOutput {
};
/**
- * @brief Perform Sparse-Matrix * Dense-Matrix as a part of the output
+ * @brief Perform Dense-Matrix * Sparse-Matrix as a part of the output
* processing pipeline.
*
* SPMDM (SParse Matrix times Dense Matrix) inplace on the 32-bit input buffer
@@ -847,6 +847,51 @@ class DoSpmdmOnInpBuffer {
const int groups_;
};
+/**
+ * @brief Perform direct sparse convolution as a part of the output
+ * processing pipeline.
+ *
+ * Applies the sparse filter part (B_csc) inplace on the 32-bit input buffer
+ * (inp), working directly on the un-im2col'ed A input. After modifying the
+ * input buffer, pass it to the next op.
+ * When groups > 1, each group is a numRows() x (numCols()/groups) matrix.
+template <
+ typename outT = std::int32_t,
+ typename inT = std::int32_t,
+ typename nextOPType = DoNothing<inT, inT>>
+class DoSConvOnInpBuffer {
+ public:
+ using outType = outT;
+ using inpType = inT;
+ DoSConvOnInpBuffer(
+ nextOPType& nextop,
+ const std::uint8_t* A,
+ const conv_param_t<>& conv_p,
+ std::int32_t A_zero_point,
+ const CompressedSparseColumn& B_csc,
+ int groups = 1)
+ : nextop_(nextop),
+ A_(A),
+ conv_p_(conv_p),
+ A_zero_point_(A_zero_point),
+ B_csc_(B_csc) {}
+
+ template <inst_set_t instSet>
+ inline int f(
+ outT* out,
+ inT* inp,
+ const block_type_t& block,
+ int ld_out,
+ int ld_in) const;
+
+ private:
+ nextOPType& nextop_;
+ const std::uint8_t* A_;
+ const conv_param_t<>& conv_p_;
+ const std::int32_t A_zero_point_;
+ const CompressedSparseColumn& B_csc_;
+};
+
enum class QuantizationGranularity {
TENSOR,
GROUP,
diff --git a/include/fbgemm/FbgemmI8Spmdm.h b/include/fbgemm/FbgemmI8Spmdm.h
index 3e040ad..ec72473 100644
--- a/include/fbgemm/FbgemmI8Spmdm.h
+++ b/include/fbgemm/FbgemmI8Spmdm.h
@@ -8,6 +8,7 @@
#include <cstdint>
#include <vector>
+#include "ConvUtils.h"
#include "Utils.h"
// #define FBGEMM_MEASURE_TIME_BREAKDOWN
@@ -21,6 +22,7 @@ extern double spmdm_transpose_32xN_time;
extern double spmdm_compute_time;
extern double spmdm_transpose_Nx32_time;
extern double spmdm_run_time;
+extern double sconv_run_time;
#endif
namespace fbgemm {
@@ -46,6 +48,19 @@ class CompressedSparseColumn {
std::vector<std::int8_t>& Values() {
return values_;
}
+ std::vector<std::int16_t>& KHs() {
+ return kh_;
+ }
+ std::vector<std::int16_t>& KWs() {
+ return kw_;
+ }
+ /**
+ * ICs include the group: i.e., for the i-th input channel within group g, ICs
+ * contain g * (input_channels_per_group) + i
+ */
+ std::vector<std::int16_t>& ICs() {
+ return ic_;
+ }
std::size_t NumOfRows() const {
return num_rows_;
@@ -83,12 +98,28 @@ class CompressedSparseColumn {
std::int32_t* C,
int ldc) const;
+ void SparseConv(
+ const conv_param_t<>& conv_p,
+ const block_type_t& block,
+ const std::uint8_t* A,
+ std::int32_t A_zero_point,
+ bool accumulation,
+ std::int32_t* C,
+ int ldc) const;
+
private:
const std::size_t num_rows_;
- std::vector<std::int32_t> colptr_;
- std::vector<std::int16_t> rowidx_;
+ std::vector<std::int32_t> colptr_; // corresponds to out channels
std::vector<std::int8_t> values_;
+ // For SpMDM
+ std::vector<std::int16_t> rowidx_; // kh, kw, and ic flattened as in im2col
+
+ // For direct sparse convolution
+ std::vector<std::int16_t> kh_;
+ std::vector<std::int16_t> kw_;
+ std::vector<std::int16_t> ic_; // in channels
+
// Cache IsHyperSparse to minimize its overhead.
mutable bool hyper_sparse_;
diff --git a/include/fbgemm/OutputProcessing-inl.h b/include/fbgemm/OutputProcessing-inl.h
index 88a10bc..e193b40 100644
--- a/include/fbgemm/OutputProcessing-inl.h
+++ b/include/fbgemm/OutputProcessing-inl.h
@@ -44,6 +44,18 @@ inline int DoSpmdmOnInpBuffer<outT, inT, nextOPType>::f(
return nextop_.template f<instSet>(out, inp, block, ld_out, ld_in);
}
+template <typename outT, typename inT, typename nextOPType>
+template <inst_set_t instSet>
+inline int DoSConvOnInpBuffer<outT, inT, nextOPType>::f(
+ outT* out,
+ inT* inp,
+ const block_type_t& block,
+ int ld_out,
+ int ld_in) const {
+ B_csc_.SparseConv(conv_p_, block, A_, A_zero_point_, true, inp, ld_in);
+ return nextop_.template f<instSet>(out, inp, block, ld_out, ld_in);
+}
+
template <
bool FUSE_RELU,
QuantizationGranularity Q_GRAN,
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc
index f1ec882..152d7f1 100644
--- a/src/ExecuteKernelU8S8.cc
+++ b/src/ExecuteKernelU8S8.cc
@@ -381,6 +381,24 @@ INSTANTIATE_Q_GRANS(true);
#undef INSTANTIATE_Q_GRANS
#undef INSTANTIATE_BASE
+#define INSTANTIATE_BASE(RELU, Q_GRAN) \
+ template class ExecuteKernel< \
+ PackAWithIm2Col<uint8_t, int16_t>, \
+ PackBMatrix<int8_t, int16_t>, \
+ uint8_t, \
+ DoSConvOnInpBuffer<uint8_t, int32_t, ReQuantizeOutput<RELU, Q_GRAN>>>;
+
+#define INSTANTIATE_Q_GRANS(RELU) \
+ INSTANTIATE_BASE(RELU, QuantizationGranularity::TENSOR); \
+ INSTANTIATE_BASE(RELU, QuantizationGranularity::GROUP); \
+ INSTANTIATE_BASE(RELU, QuantizationGranularity::OUT_CHANNEL);
+
+INSTANTIATE_Q_GRANS(false);
+INSTANTIATE_Q_GRANS(true);
+
+#undef INSTANTIATE_Q_GRANS
+#undef INSTANTIATE_BASE
+
template class ExecuteKernel<
PackAWithRowOffset<uint8_t, int16_t>,
PackBMatrix<int8_t, int16_t>,
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index a8bf02f..6623fe7 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -376,6 +376,31 @@ INSTANTIATE_Q_GRANS(true);
#undef INSTANTIATE_Q_GRANS
#undef INSTANTIATE_BASE
+#define INSTANTIATE_BASE(RELU, Q_GRAN) \
+ template void fbgemmPacked( \
+ PackMatrix<PackAWithIm2Col<uint8_t, int16_t>, uint8_t, int16_t>& packA, \
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, \
+ uint8_t* C, \
+ int32_t* C_buffer, \
+ uint32_t ldc, \
+ const DoSConvOnInpBuffer< \
+ uint8_t, \
+ int32_t, \
+ ReQuantizeOutput<RELU, Q_GRAN>>& outProcess, \
+ int thread_id, \
+ int num_threads);
+
+#define INSTANTIATE_Q_GRANS(RELU) \
+ INSTANTIATE_BASE(RELU, QuantizationGranularity::TENSOR); \
+ INSTANTIATE_BASE(RELU, QuantizationGranularity::GROUP); \
+ INSTANTIATE_BASE(RELU, QuantizationGranularity::OUT_CHANNEL);
+
+INSTANTIATE_Q_GRANS(false);
+INSTANTIATE_Q_GRANS(true);
+
+#undef INSTANTIATE_Q_GRANS
+#undef INSTANTIATE_BASE
+
template void fbgemmPacked(
PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
diff --git a/src/FbgemmI8Spmdm.cc b/src/FbgemmI8Spmdm.cc
index 12e1cb2..125b26d 100644
--- a/src/FbgemmI8Spmdm.cc
+++ b/src/FbgemmI8Spmdm.cc
@@ -21,6 +21,7 @@ double spmdm_transpose_32xN_time = 0.0;
double spmdm_compute_time = 0.0;
double spmdm_transpose_Nx32_time = 0.0;
double spmdm_run_time = 0.0;
+double sconv_run_time = 0.0;
#endif
using namespace std;
@@ -222,8 +223,8 @@ void CompressedSparseColumn::SpMDM(
t_very_start = std::chrono::high_resolution_clock::now();
#endif
- uint8_t A_buffer[K * 32] __attribute__((aligned(64)));
- int32_t C_buffer[N * 32] __attribute__((aligned(64)));
+ alignas(64) uint8_t A_buffer[K * 32];
+ alignas(64) int32_t C_buffer[N * 32];
// If we compute C = C + A * B, where B is a sparse matrix in CSC format, for
// each non-zero in B, we'd need to access the corresponding column in A.
@@ -269,7 +270,7 @@ void CompressedSparseColumn::SpMDM(
for (int i1 = block.row_start; i1 < i_end; i1 += 32) {
// Transpose 32 x K submatrix of A
if (i_end - i1 < 32) {
- uint8_t A_temp_buffer[K * 32] __attribute__((aligned(64)));
+ alignas(64) uint8_t A_temp_buffer[K * 32];
for (int i2 = 0; i2 < (i_end - i1) / 8 * 8; i2 += 8) {
transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32);
}
@@ -505,4 +506,69 @@ void CompressedSparseColumn::SpMDM(
#endif
}
+void CompressedSparseColumn::SparseConv(
+ const conv_param_t<>& conv_p,
+ const block_type_t& block,
+ const uint8_t* A,
+ int32_t A_zero_point,
+ bool accumulation,
+ int32_t* C,
+ int ldc) const {
+ int K = NumOfRows();
+ int N = block.col_size;
+
+ if (K == 0 || N == 0) {
+ return;
+ }
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ std::chrono::time_point<std::chrono::high_resolution_clock> t_start, t_end;
+ double dt;
+ t_start = std::chrono::high_resolution_clock::now();
+#endif
+
+ // TODO: if not hyper sparse, transpose a block of A matrix as in SpMDM.
+ if (!accumulation) {
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ for (int j = block.col_start; j < block.col_start + block.col_size;
+ ++j) {
+ C[(i - block.row_start) * ldc + j - block.col_start] = 0;
+ }
+ }
+ }
+ for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
+ for (int k = colptr_[j]; k < colptr_[j + 1]; ++k) {
+ int v = values_[k];
+ for (int i = block.row_start; i < block.row_start + block.row_size;
+ ++i) {
+ int ow = i % conv_p.OUT_DIM[1];
+ int oh = i / conv_p.OUT_DIM[1] % conv_p.OUT_DIM[0];
+ int n = i / conv_p.OUT_DIM[1] / conv_p.OUT_DIM[0];
+ assert(n < conv_p.MB);
+ int iw = -conv_p.pad[1] + ow * conv_p.stride[1] + kw_[k];
+ int ih = -conv_p.pad[0] + oh * conv_p.stride[0] + kh_[k];
+
+ if (ih >= 0 && ih < conv_p.IN_DIM[0] && iw >= 0 &&
+ iw < conv_p.IN_DIM[1]) {
+ C[(i - block.row_start) * ldc + j - block.col_start] +=
+ A[((n * conv_p.IN_DIM[0] + ih) * conv_p.IN_DIM[1] + iw) *
+ conv_p.IC +
+ ic_[k]] *
+ v;
+ } else {
+ C[(i - block.row_start) * ldc + j - block.col_start] +=
+ A_zero_point * v;
+ }
+ }
+ }
+ } // for each column of B
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ t_end = std::chrono::high_resolution_clock::now();
+ dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
+ .count();
+ sconv_run_time += (dt);
+#endif
+}
+
} // namespace fbgemm
diff --git a/test/Im2ColFusedRequantizeTest.cc b/test/Im2ColFusedRequantizeTest.cc
index d8f3f7a..30a905d 100644
--- a/test/Im2ColFusedRequantizeTest.cc
+++ b/test/Im2ColFusedRequantizeTest.cc
@@ -6,6 +6,7 @@
*/
#include <cmath>
#include <cstdio>
+#include <random>
#ifdef _OPENMP
#include <omp.h>
@@ -139,8 +140,8 @@ static void Im2colTest() {
im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data());
// computing column offset
- vector<int32_t> col_offsets(groups * NDim);
- for (int g = 0; g < groups; ++g) {
+ vector<int32_t> col_offsets(conv_p.G * NDim);
+ for (int g = 0; g < conv_p.G; ++g) {
col_offsets_with_zero_pt_s8acc32_ref(
KDimPerGroup,
NDim,
@@ -158,7 +159,7 @@ static void Im2colTest() {
Bint8.data(),
Cint32_ref.data());
- for (int g = 0; g < groups; ++g) {
+ for (int g = 0; g < conv_p.G; ++g) {
row_offsets_u8acc32_ref(
MDim,
KDimPerGroup,
@@ -278,6 +279,225 @@ TEST_P(fbgemmIm2colTest, Acc16Test) {
}
}
+template<QuantizationGranularity Q_GRAN>
+void SConvTest() {
+ for (auto conv_p : shapes) {
+ for (int groups : {1, 4}) {
+ if (conv_p.IC % groups != 0 || conv_p.OC % groups != 0) {
+ continue;
+ }
+ conv_p.G = groups;
+ aligned_vector<uint8_t> Aint8(
+ conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC);
+ aligned_vector<int8_t> Bint8(
+ conv_p.K[0] * conv_p.K[1] * conv_p.IC * conv_p.OC);
+ aligned_vector<int32_t> Cint32_ref(
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC);
+ aligned_vector<uint8_t> Cint8_ref(Cint32_ref.size());
+ aligned_vector<int32_t> Cint32_fb(Cint32_ref.size());
+ aligned_vector<uint8_t> Cint8_fb(Cint32_ref.size());
+
+ int ncols_per_quant_group = conv_p.OC;
+ if (Q_GRAN == QuantizationGranularity::GROUP) {
+ ncols_per_quant_group = conv_p.OC / conv_p.G;
+ } else if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ ncols_per_quant_group = 1;
+ }
+ int32_t Aint8_zero_point;
+ aligned_vector<int32_t> Bint8_zero_point(
+ conv_p.OC / ncols_per_quant_group);
+ randFill<uint8_t>(Aint8, 0, 5);
+ Aint8_zero_point = 4;
+ randFill<int8_t>(Bint8, -4, 4);
+ randFill(Bint8_zero_point, -3, -1);
+
+ aligned_vector<float> C_multiplier(Bint8_zero_point.size());
+ randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2);
+ int32_t C_zero_pt = 5;
+
+ int MDim = conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1];
+ int NDim = conv_p.OC / conv_p.G;
+ int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC;
+ int KDimPerGroup = KDim / conv_p.G;
+
+ // computing row offset
+ vector<int32_t> row_offsets(MDim);
+ vector<uint8_t> Aint8_im2col(MDim * KDim);
+ im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data());
+
+ // computing column offset
+ vector<int32_t> col_offsets(conv_p.G * NDim);
+ for (int g = 0; g < conv_p.G; ++g) {
+ col_offsets_with_zero_pt_s8acc32_ref(
+ KDimPerGroup,
+ NDim,
+ NDim,
+ Bint8.data() + g * KDimPerGroup * NDim,
+ Bint8_zero_point.data() + g * NDim / ncols_per_quant_group,
+ col_offsets.data() + g * NDim,
+ ncols_per_quant_group);
+ }
+
+ conv_ref(
+ conv_p,
+ Aint8.data(),
+ Aint8_zero_point,
+ Bint8.data(),
+ Cint32_ref.data());
+
+ for (int g = 0; g < conv_p.G; ++g) {
+ row_offsets_u8acc32_ref(
+ MDim,
+ KDimPerGroup,
+ KDim,
+ Aint8_im2col.data() + g * KDimPerGroup,
+ row_offsets.data());
+
+ requantize_u8acc32_ref(
+ MDim,
+ NDim,
+ conv_p.G * NDim,
+ Cint32_ref.data() + g * NDim,
+ Cint8_ref.data() + g * NDim,
+ C_multiplier.data() + g * NDim / ncols_per_quant_group,
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data() + g * NDim / ncols_per_quant_group,
+ row_offsets.data(),
+ col_offsets.data() + g * NDim,
+ nullptr,
+ ncols_per_quant_group);
+ }
+
+ float density = 0.0001f;
+ CompressedSparseColumn B_csc(KDimPerGroup, conv_p.G * NDim);
+ default_random_engine eng;
+ binomial_distribution<> per_col_nnz_dist(KDimPerGroup, density);
+
+ // TODO: refactor CSC construction as a reusable function
+ vector<int> row_indices(KDimPerGroup);
+ int total_nnz = 0;
+ int ic_per_group = conv_p.IC / conv_p.G;
+ for (int g = 0; g < conv_p.G; ++g) {
+ for (int j = 0; j < NDim; ++j) {
+ B_csc.ColPtr()[g * NDim + j] = total_nnz;
+
+ int nnz_of_j = per_col_nnz_dist(eng);
+ total_nnz += nnz_of_j;
+
+ iota(row_indices.begin(), row_indices.end(), 0);
+ shuffle(row_indices.begin(), row_indices.end(), eng);
+ sort(row_indices.begin(), row_indices.begin() + nnz_of_j);
+
+ for (int kidx = 0; kidx < nnz_of_j; ++kidx) {
+ int rowidx = row_indices[kidx];
+ int ic = g * ic_per_group + rowidx % ic_per_group;
+ int kw = rowidx / ic_per_group % conv_p.K[1];
+ int kh = rowidx / ic_per_group / conv_p.K[1];
+ assert(kh < conv_p.K[0]);
+
+ B_csc.KHs().push_back(kh);
+ B_csc.KWs().push_back(kw);
+ B_csc.ICs().push_back(ic);
+
+ int8_t* bptr = &Bint8[(g * KDimPerGroup + rowidx) * NDim + j];
+ B_csc.Values().push_back(*bptr);
+ *bptr = 0;
+ }
+ }
+ }
+ B_csc.ColPtr()[conv_p.G * NDim] = total_nnz;
+
+ PackBMatrix<int8_t, int16_t> packedB(
+ matrix_op_t::NoTranspose,
+ KDim,
+ NDim,
+ Bint8.data(),
+ NDim,
+ nullptr,
+ conv_p.G);
+
+#ifdef _OPENMP
+#pragma omp parallel
+#endif
+ {
+ vector<int32_t> row_offset_buf(
+ PackAWithIm2Col<uint8_t, int16_t>::rowOffsetBufferSize());
+
+ PackAWithIm2Col<uint8_t, int16_t> packA(
+ conv_p,
+ Aint8.data(),
+ nullptr,
+ Aint8_zero_point,
+ row_offset_buf.data());
+
+ DoNothing<> doNothingObj{};
+ ReQuantizeOutput<false, Q_GRAN> reqObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ packA.getRowOffsetBuffer(),
+ col_offsets.data(),
+ nullptr,
+ conv_p.G * NDim,
+ conv_p.G);
+ DoSConvOnInpBuffer<
+ ReQuantizeOutput<false>::outType,
+ int32_t,
+ ReQuantizeOutput<false, Q_GRAN>>
+ sconvObj(reqObj, Aint8.data(), conv_p, Aint8_zero_point, B_csc);
+
+ int num_threads = fbgemm_get_num_threads();
+ int tid = fbgemm_get_thread_num();
+
+ fbgemmPacked(
+ packA,
+ packedB,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ conv_p.G * NDim,
+ sconvObj,
+ tid,
+ num_threads);
+ } // omp parallel
+
+ // correctness check
+ for (int n = 0; n < conv_p.MB; ++n) {
+ for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) {
+ for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) {
+ for (int k = 0; k < conv_p.OC; ++k) {
+ int32_t expected = Cint8_ref
+ [((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) *
+ conv_p.OC +
+ k];
+ int32_t actual = Cint8_fb
+ [((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) *
+ conv_p.OC +
+ k];
+ EXPECT_EQ(expected, actual)
+ << "SConv fused results differ at (" << n << ", " << h
+ << ", " << w << ", " << k << ").";
+ }
+ }
+ }
+ }
+ } // for each groups
+ } // for each shape
+}
+
+TEST_P(fbgemmIm2colTest, SConvTest) {
+ QuantizationGranularity q_granularity = GetParam();
+ if (q_granularity == QuantizationGranularity::TENSOR) {
+ SConvTest<QuantizationGranularity::TENSOR>();
+ } else if (q_granularity == QuantizationGranularity::GROUP) {
+ SConvTest<QuantizationGranularity::GROUP>();
+ } else {
+ SConvTest<QuantizationGranularity::OUT_CHANNEL>();
+ }
+}
+
static vector<conv_param_t<3>> shapes_3d = {
// MB, IC, OC, IT, IH, IW, G, KT, KH, KW, stride_t, stride_h, stride_w,
// pad_t, pad_h, pad_w
@@ -410,8 +630,8 @@ static void Im2col3DTest() {
im2col3d_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data());
// computing column offset
- vector<int32_t> col_offsets(groups * NDim);
- for (int g = 0; g < groups; ++g) {
+ vector<int32_t> col_offsets(conv_p.G * NDim);
+ for (int g = 0; g < conv_p.G; ++g) {
col_offsets_with_zero_pt_s8acc32_ref(
KDimPerGroup,
NDim,
@@ -429,7 +649,7 @@ static void Im2col3DTest() {
Bint8.data(),
Cint32_ref.data());
- for (int g = 0; g < groups; ++g) {
+ for (int g = 0; g < conv_p.G; ++g) {
row_offsets_u8acc32_ref(
MDim,
KDimPerGroup,