diff options
Diffstat (limited to 'test/UniConvTest.cc')
-rw-r--r-- | test/UniConvTest.cc | 714 |
1 files changed, 714 insertions, 0 deletions
diff --git a/test/UniConvTest.cc b/test/UniConvTest.cc new file mode 100644 index 0000000..e9c7ba5 --- /dev/null +++ b/test/UniConvTest.cc @@ -0,0 +1,714 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <algorithm> +#include <iostream> +#include <random> +#include <stdexcept> + +#include <gtest/gtest.h> + +#include "QuantizationHelpers.h" +#include "TestUtils.h" +#include "bench/BenchUtils.h" +#include "fbgemm/Fbgemm.h" +#include "src/RefImplementations.h" + +using namespace std; +using namespace fbgemm; + +// clang-format off +static vector<conv_param_t<>> GetShapes_() { + vector<conv_param_t<>> shapes = { + // MB, IC, OC, {IH, IW}, G, {KH, KW}, {stride_h, stride_w}, {pad_t, pad_l, + // pad_b, pad_r} + // Regular + conv_param_t<>(1, 16, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 16, 32, {30, 10}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {0, 0, 0, 0}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {0, 0, 0, 0}, {2, 2}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {3, 3}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {2, 2}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {2, 1}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {1, 2}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {2, 1, 2, 1}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 2, 1, 2}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 2}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 16, 16, {10, 30}, 1, {3, 5}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 16, 16, {10, 30}, 1, {5, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 16, 16, {10, 30}, 1, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}), + // groupwise + conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 16, {10, 30}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 16, 32, {10, 30}, 8, {3, 3}, {2, 2}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {2, 2}, {2, 1, 2, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {1, 2}, {2, 1, 2, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {2, 1}, {2, 1, 2, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 5}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 8, {5, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 8, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}), + // DW + conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {2, 2}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 1}, {1, 2, 1, 2}), + conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {2, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 2}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 5}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 32, {5, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 32, 32, {10, 30}, 32, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}), + // Pointwise + conv_param_t<>(1, 32, 32, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}), + conv_param_t<>(1, 16, 32, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {2, 2}, {0, 0, 0, 0}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {1, 2}, {0, 0, 0, 0}), + conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {2, 1}, {0, 0, 0, 0}), + }; + return shapes; +} +// clang-format on + +namespace { + +// tuple represents MB, IC, OC, IT, IH, IW, KH/KW, stride, pad +class uniConvTest + : public testing::TestWithParam< + tuple<int, int, int, int, int, int, int, int, int, int>> {}; + +// tuple represents QuantizationGranularity, A symmetric, B symmetric, +// test_bias, test_float_bias +class UniConvQGranTest + : public testing::TestWithParam< + tuple<QuantizationGranularity, bool, bool, bool, bool>> {}; + +}; // namespace + +// Combine only allows at most 10 generators. +INSTANTIATE_TEST_CASE_P( + InstantiationName, + uniConvTest, + ::testing::Combine( + ::testing::ValuesIn({1, 2}), // MB + ::testing::ValuesIn({16, 32}), // IC + ::testing::ValuesIn({16, 32}), // OC + ::testing::ValuesIn({17}), // IT + ::testing::ValuesIn({10, 30, 55}), // IH + ::testing::ValuesIn({10, 30, 55}), // IW + ::testing::ValuesIn({1, 4, 16}), // G + ::testing::ValuesIn({1, 3, 7}), // kernel + ::testing::ValuesIn({1, 2}), // stride + ::testing::ValuesIn({0, 1, 2}))); // pad + +INSTANTIATE_TEST_CASE_P( + InstantiationName, + UniConvQGranTest, + ::testing::Combine( + ::testing::ValuesIn(qGranularityVals), + ::testing::Bool(), // A symmetric + ::testing::Bool(), // B symmetric + ::testing::Bool(), // test_bias + ::testing::Bool())); // test_float_bias +/** + * Test for conv packing + */ +TEST_P(uniConvTest, packingTest) { + int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad; + tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam(); + + conv_param_t<2> conv_p_2d( + MB, + IC, + OC, + {IH, IW}, + G, + {kernel, kernel}, + {stride, stride}, + {pad, pad, pad, pad}); + + int kernel_dim_2d = kernel * kernel; + aligned_vector<int8_t> Bint8_2d( + kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); + PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data()); + + switch (ConvFastPath<2, int32_t>(conv_p_2d)) { + case optimized_conv_t::depthwise: { + ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr) + << "pointwise packed matrix should be null"; + ASSERT_NE(packedB_2D.getPackedWForDepthwise(), nullptr) + << "depthwise packed matrix is null"; + break; + } + case optimized_conv_t::groupwise: { + ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr) + << "depthwise packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr) + << "pointwise packed matrix should be null"; + ASSERT_NE(packedB_2D.getPackedWForGroupwise(), nullptr) + << "Groupwise packed matrix is null"; + break; + } + case optimized_conv_t::pointwise: { + ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr) + << "depthwise packed matrix should null"; + ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr) + << "Groupwise packed matrix should be null"; + ASSERT_NE(packedB_2D.getPackedWForPointwise(), nullptr) + << "pointwise packed matrix is null"; + break; + } + case optimized_conv_t::im2col: { + ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr) + << "depthwise packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr) + << "pointwise packed matrix should be null"; + ASSERT_NE(packedB_2D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix is null"; + break; + } + } + + conv_param_t<3> conv_p_3d( + MB, + IC, + OC, + {IT, IH, IW}, + G, + {kernel, kernel, kernel}, + {stride, stride, stride}, + {pad, pad, pad, pad, pad, pad}); + + int kernel_dim_3d = kernel * kernel * kernel; + aligned_vector<int8_t> Bint8_3d( + kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G)); + PackWeightsForConv<3> packedB_3D(conv_p_3d, Bint8_3d.data()); + + switch (ConvFastPath<3, int32_t>(conv_p_3d)) { + case optimized_conv_t::depthwise: { + ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix should be null"; + ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr) + << "pointwise packed matrix should be null"; + ASSERT_NE(packedB_3D.getPackedWForDepthwise(), nullptr) + << "depthwise packed matrix is null"; + break; + } + case optimized_conv_t::groupwise: { + ASSERT_TRUE(false) << "groupwise are not supported for 3D"; + break; + } + case optimized_conv_t::pointwise: { + ASSERT_EQ(packedB_3D.getPackedWForDepthwise(), nullptr) + << "depthwise packed matrix should be null"; + ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix should be null"; + ASSERT_NE(packedB_3D.getPackedWForPointwise(), nullptr) + << "pointwise packed matrix is null"; + break; + } + case optimized_conv_t::im2col: { + ASSERT_EQ(packedB_3D.getPackedWForDepthwise(), nullptr) + << "depthwise packed matrix should be null"; + ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr) + << "groupwise packed matrix should be null"; + ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr) + << "pointwise packed matrix should be null"; + ASSERT_NE(packedB_3D.getPackedWForIm2col(), nullptr) + << "im2col packed matrix is null"; + break; + } + } +} + +/** + * Test for packing/unpacking + */ +TEST_P(uniConvTest, packUnpackTest) { + int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad; + tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam(); + + conv_param_t<2> conv_p_2d( + MB, + IC, + OC, + {IH, IW}, + G, + {kernel, kernel}, + {stride, stride}, + {pad, pad, pad, pad}); + + int kernel_dim_2d = kernel * kernel; + + aligned_vector<int8_t> Bint8_2d( + kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); + aligned_vector<int8_t> Bint8_2d_unpacked( + kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); + + PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data()); + + packedB_2D.unpack(Bint8_2d_unpacked.data()); + + ASSERT_EQ(Bint8_2d, Bint8_2d_unpacked) + << "Original and unpacked data elements are not the same [2D]"; + + conv_param_t<3> conv_p_3d( + MB, + IC, + OC, + {IT, IH, IW}, + G, + {kernel, kernel, kernel}, + {stride, stride, stride}, + {pad, pad, pad, pad, pad, pad}); + + int kernel_dim_3d = kernel * kernel * kernel; + + aligned_vector<int8_t> Bint8_3d( + kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G)); + + aligned_vector<int8_t> Bint8_3d_unpacked( + kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G)); + + PackWeightsForConv<3> packedB_3D(conv_p_3d, Bint8_3d.data()); + + packedB_3D.unpack(Bint8_3d_unpacked.data()); + + ASSERT_EQ(Bint8_3d, Bint8_3d_unpacked) + << "Original and unpacked data elements are not the same [3D]"; +} + +TEST(uniConvTest, cornerCases) { + int stride = 1; + conv_param_t<2> conv_p_2d( + 1, // mini-batch + 16, // input channels + 32, // output channels + {28, 28}, // input height/width + 4, // groups + {3, 3}, // kernel height/width + {stride, stride}, // strides + {1, 1, 1, 1}); // padding + + int kernel_dim_2d = conv_p_2d.K[0] * conv_p_2d.K[1]; + + aligned_vector<uint8_t> Aint8( + conv_p_2d.MB * conv_p_2d.IN_DIM[0] * conv_p_2d.IN_DIM[1] * conv_p_2d.IC); + aligned_vector<int8_t> Bint8_2d( + kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); + aligned_vector<int32_t> Cint32_fb( + conv_p_2d.MB * conv_p_2d.OUT_DIM[0] * conv_p_2d.OUT_DIM[1] * + conv_p_2d.OC); + aligned_vector<uint8_t> Cint8_fb(Cint32_fb.size(), 0); + + // A matrix (input activations) + randFill<uint8_t>(Aint8, 0, 5); + int32_t Aint8_zero_point = 4; + + // B matrix (weights) + randFill<int8_t>(Bint8_2d, -4, 4); + aligned_vector<int32_t> Bint8_zero_point(1); + randFill(Bint8_zero_point, -3, -1); + + aligned_vector<float> C_multiplier(Bint8_zero_point.size()); + randFill(C_multiplier, 0.1234f / 2, 0.1234f * 3 / 2); + int32_t C_zero_point = 5; + + PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data()); + + vector<int32_t> col_offsets(conv_p_2d.OC); + + DoNothing<> doNothingObj{}; + ReQuantizeOutput<false, QuantizationGranularity::TENSOR> outputProcObj( + doNothingObj, + C_multiplier.data(), + C_zero_point, + Aint8_zero_point, + Bint8_zero_point.data(), + nullptr, // row offsets + col_offsets.data(), + nullptr, // bias + conv_p_2d.OC, + conv_p_2d.G); + + try { + conv_p_2d.stride[0] = 2; + fbgemmConv( + conv_p_2d, + Aint8.data(), + packedB_2D, + Cint8_fb.data(), + Cint32_fb.data(), + outputProcObj, + 0, + 1); + } catch (std::logic_error const& err) { + std::string s(err.what()); + EXPECT_TRUE(s.rfind("[FBGEMM_CONV_ERROR]", 0) == 0); + } + + // reset + conv_p_2d.stride[0] = stride; + // this should run fine + fbgemmConv( + conv_p_2d, + Aint8.data(), + packedB_2D, + Cint8_fb.data(), + Cint32_fb.data(), + outputProcObj, + 0, + 1); +} + +/** + * @brief Unit test for uint8 activations, int8 weights, and 32-bit + * accumulation. Output processing: requantization -> nothing + */ +TEST_P(UniConvQGranTest, requantizeTest) { + vector<conv_param_t<>> shapes(GetShapes_()); + QuantizationGranularity q_granularity; + bool a_symmetric, b_symmetric; + bool test_bias, test_float_bias; + tie(q_granularity, a_symmetric, b_symmetric, test_bias, test_float_bias) = + GetParam(); + + for (auto conv_p : shapes) { + int R = conv_p.K[0]; + int S = conv_p.K[1]; + int G = conv_p.G; + int OC = conv_p.OC; + int OH = conv_p.OUT_DIM[0]; + int OW = conv_p.OUT_DIM[1]; + int IC_per_G = conv_p.IC / conv_p.G; + int OC_per_G = conv_p.OC / conv_p.G; + + // activations + aligned_vector<uint8_t> Aint8( + conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 0); + + // weights + // The weight matrix is in layout G K/G (R S C/G) + aligned_vector<int8_t> Bint8(R * S * conv_p.G * IC_per_G * OC_per_G, 0); + aligned_vector<int8_t> Bint8_tr(R * S * G * IC_per_G * OC_per_G, 0); + + aligned_vector<int32_t> Cint32_ref(conv_p.MB * OH * OW * OC, 0); + aligned_vector<int32_t> Cint32_fb(Cint32_ref.size(), 0); + aligned_vector<uint8_t> Cint8_ref(Cint32_ref.size(), 0); + aligned_vector<uint8_t> Cint8_fb(Cint32_ref.size(), 0); + + randFill<uint8_t>(Aint8, 0, 5); + int32_t Aint8_zero_point = a_symmetric ? 0 : 4; + + randFill<int8_t>(Bint8, -4, 4); + + // computing column offset + vector<int32_t> col_offsets(G * OC_per_G); + + int ncols_per_quant_group = G * OC_per_G; + if (q_granularity == QuantizationGranularity::GROUP) { + ncols_per_quant_group = OC_per_G; + } else if (q_granularity == QuantizationGranularity::OUT_CHANNEL) { + ncols_per_quant_group = 1; + } + + aligned_vector<int32_t> Bint8_zero_point( + G * OC_per_G / ncols_per_quant_group); + if (b_symmetric) { + randFill(Bint8_zero_point, -3, 3); + } else { + randFill(Bint8_zero_point, 0, 0); + } + + // matrix dimensions after im2col for each GEMM. + // For each group, there is one GEMM of the following dimensions + int MDim = conv_p.MB * OH * OW; + int NDim = OC_per_G; + int KDim = R * S * IC_per_G; + + vector<uint8_t> Aint8_im2col(MDim * KDim * G); + im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data()); + + vector<int32_t> row_offsets(MDim); + + // activation_scale * weight_scale + aligned_vector<float> act_times_w_scale(Bint8_zero_point.size()); + randFill(act_times_w_scale, 0.1234f / 2, 0.1234f * 3 / 2); + + float out_scale = 2.0f; + aligned_vector<float> C_multiplier(Bint8_zero_point.size()); + transform( + act_times_w_scale.begin(), + act_times_w_scale.end(), + C_multiplier.begin(), + [&out_scale](float i) { return i / out_scale; }); + + int32_t C_zero_pt = 5; + + // initialize bias + aligned_vector<int32_t> bias_int32(OC); + aligned_vector<float> bias_fp32(OC); + if (test_bias) { + randFill(bias_int32, -8, 8); + } + + // floating point bias + if (test_float_bias) { + if (q_granularity == QuantizationGranularity::TENSOR) { + transform( + bias_int32.begin(), + bias_int32.end(), + bias_fp32.begin(), + [&act_times_w_scale](float i) { return i * act_times_w_scale[0]; }); + } else if (q_granularity == QuantizationGranularity::GROUP) { + for (int g = 0; g < G; ++g) { + for (int c = 0; c < OC_per_G; ++c) { + bias_fp32[g * OC_per_G + c] = act_times_w_scale[g] * + static_cast<float>(bias_int32[g * OC_per_G + c]); + } + } + } else { // OUT_CHANNEL + transform( + act_times_w_scale.begin(), + act_times_w_scale.end(), + bias_int32.begin(), + bias_fp32.begin(), + multiplies<float>()); + } + } + // reference implementation + // conv_ref expects weights to be in G (R S C/G) K/G + int8_t* rightBData = Bint8.data(); + transposeConvWeights(conv_p, Bint8.data(), Bint8_tr.data()); + rightBData = Bint8_tr.data(); + for (int g = 0; g < G; ++g) { + col_offsets_with_zero_pt_s8acc32_ref( + R * S * IC_per_G, + OC_per_G, + OC_per_G, + rightBData + g * R * S * IC_per_G * OC_per_G, + Bint8_zero_point.data() + g * OC_per_G / ncols_per_quant_group, + col_offsets.data() + g * OC_per_G, + ncols_per_quant_group); + } + conv_ref( + conv_p, Aint8.data(), Aint8_zero_point, rightBData, Cint32_ref.data()); + + for (int g = 0; g < G; ++g) { + row_offsets_u8acc32_ref( + MDim, + KDim, + KDim * G, + Aint8_im2col.data() + g * KDim, + row_offsets.data()); + + requantize_u8acc32_ref( + MDim, + NDim, + G * NDim, + Cint32_ref.data() + g * NDim, + Cint8_ref.data() + g * NDim, + C_multiplier.data() + g * NDim / ncols_per_quant_group, + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data() + g * NDim / ncols_per_quant_group, + row_offsets.data(), + col_offsets.data() + g * NDim, + test_bias ? bias_int32.data() + g * NDim : nullptr, + ncols_per_quant_group); + } + + PackWeightsForConv<2> packedWeights(conv_p, Bint8.data()); + + // TODO: Uncomment once we support multiple threads in fbgemmGroupwiseConv + // #ifdef _OPENMP + // #pragma omp parallel + // #endif + { + vector<int32_t> row_offset_buf(rowOffsetBufferSizeGConv(conv_p)); + + DoNothing<> doNothingObj{}; + + int num_threads = fbgemm_get_num_threads(); + int tid = fbgemm_get_thread_num(); + + if (q_granularity == QuantizationGranularity::TENSOR) { + if (test_float_bias) { + ReQuantizeOutput<false, QuantizationGranularity::TENSOR, float> + reqObj( + doNothingObj, + C_multiplier.data(), + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data(), + nullptr, /* row offset buffer */ + col_offsets.data(), + test_bias ? bias_fp32.data() : nullptr, + G * NDim, + G, + act_times_w_scale.data()); + + fbgemmConv( + conv_p, + Aint8.data(), + packedWeights, + Cint8_fb.data(), + Cint32_fb.data(), + reqObj, + tid, + num_threads); + + } else { + ReQuantizeOutput<false, QuantizationGranularity::TENSOR> reqObj( + doNothingObj, + C_multiplier.data(), + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data(), + nullptr, /* row offset buffer */ + col_offsets.data(), + test_bias ? bias_int32.data() : nullptr, + G * NDim, + G); + + fbgemmConv( + conv_p, + Aint8.data(), + packedWeights, + Cint8_fb.data(), + Cint32_fb.data(), + reqObj, + tid, + num_threads); + } + + } else if (q_granularity == QuantizationGranularity::GROUP) { + if (test_float_bias) { + ReQuantizeOutput<false, QuantizationGranularity::GROUP, float> reqObj( + doNothingObj, + C_multiplier.data(), + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data(), + nullptr, /* row offset buffer */ + col_offsets.data(), + test_bias ? bias_fp32.data() : nullptr, + G * NDim, + G, + act_times_w_scale.data()); + + fbgemmConv( + conv_p, + Aint8.data(), + packedWeights, + Cint8_fb.data(), + Cint32_fb.data(), + reqObj, + tid, + num_threads); + + } else { + ReQuantizeOutput<false, QuantizationGranularity::GROUP> reqObj( + doNothingObj, + C_multiplier.data(), + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data(), + nullptr, /* row offset buffer */ + col_offsets.data(), + test_bias ? bias_int32.data() : nullptr, + G * NDim, + G); + + fbgemmConv( + conv_p, + Aint8.data(), + packedWeights, + Cint8_fb.data(), + Cint32_fb.data(), + reqObj, + tid, + num_threads); + } + + } else { + if (test_float_bias) { + ReQuantizeOutput<false, QuantizationGranularity::OUT_CHANNEL, float> + reqObj( + doNothingObj, + C_multiplier.data(), + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data(), + nullptr, /* row offset buffer */ + col_offsets.data(), + test_bias ? bias_fp32.data() : nullptr, + G * NDim, + G, + act_times_w_scale.data()); + + fbgemmConv( + conv_p, + Aint8.data(), + packedWeights, + Cint8_fb.data(), + Cint32_fb.data(), + reqObj, + tid, + num_threads); + + } else { + ReQuantizeOutput<false, QuantizationGranularity::OUT_CHANNEL> reqObj( + doNothingObj, + C_multiplier.data(), + C_zero_pt, + Aint8_zero_point, + Bint8_zero_point.data(), + nullptr, /* row offset buffer */ + col_offsets.data(), + test_bias ? bias_int32.data() : nullptr, + G * NDim, + G); + + fbgemmConv( + conv_p, + Aint8.data(), + packedWeights, + Cint8_fb.data(), + Cint32_fb.data(), + reqObj, + tid, + num_threads); + } + } + } // omp parallel + + compare_validate_buffers( + Cint8_ref.data(), + Cint8_fb.data(), + MDim, + NDim * G, + NDim * G, + static_cast<uint8_t>(0)); + } // for each shape +} |