author     Jongsoo Park <jongsoo@fb.com>  2019-02-01 22:50:44 +0300
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-02-01 22:53:50 +0300
commit     d08c09ac24fc16e4fbc150b81d01b9da41309611 (patch)
tree       20828ba9008f20aba21541e9c2417c23018a1fcd
parent     a98f5d8ffe070b0124dcd34963e6ae55c7864407 (diff)
specialized requantization for gconv (#61)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/61
Requantization was the bottleneck of group convolution with 4 channels per group. This diff implements a version of requantization specialized for that case.
TODO: generalize for other group convolution shapes.
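For reference, the computation being specialized here is the usual zero-point-corrected requantization of the conv kernel's int32 accumulators. A minimal scalar sketch of one output element (illustrative names and per-tensor granularity; this is not FBGEMM's API):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar reference for one output element; the AVX2 kernel in this diff
// computes 32 of these per iteration (8 groups x 4 channels per group).
std::uint8_t requantize_ref(
    std::int32_t acc, // int32 accumulator from the conv kernel
    std::int32_t A_zero_point, // activation zero point
    std::int32_t B_zero_point, // weight zero point
    std::int32_t row_offset, // sum of activations feeding this output
    std::int32_t col_offset, // sum of weights in this output channel
    std::int32_t bias,
    float C_multiplier, // (act_scale * weight_scale) / out_scale
    std::int32_t C_zero_point) { // output zero point
  // Remove the cross terms introduced by asymmetric quantization, add bias.
  std::int32_t corrected =
      acc - A_zero_point * col_offset - B_zero_point * row_offset + bias;
  // Scale to the output domain; round to nearest, ties to even (matching
  // what CVTPS2DQ does under the default MXCSR rounding mode).
  std::int32_t rounded =
      static_cast<std::int32_t>(std::nearbyintf(corrected * C_multiplier));
  // Shift by the output zero point and saturate to uint8.
  return static_cast<std::uint8_t>(
      std::min<std::int32_t>(255, std::max<std::int32_t>(0, rounded + C_zero_point)));
}
```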
Reviewed By: dskhudia
Differential Revision: D13831466
fbshipit-source-id: 1ac7225d3133a2304c5b07730374584afc6ec259
-rw-r--r--  include/fbgemm/Fbgemm.h          |  38
-rw-r--r--  include/fbgemm/QuantUtilsAvx2.h  |  14
-rw-r--r--  src/GroupwiseConvAcc32Avx2.cc    | 240
-rw-r--r--  src/QuantUtilsAvx2.cc            | 349
4 files changed, 611 insertions(+), 30 deletions(-)
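The new `fbgemmGroupwiseConv` overload below maps three runtime properties (`a_zero_point == 0`, symmetric weight quantization, presence of a bias) onto the `A_SYMMETRIC`/`B_SYMMETRIC`/`HAS_BIAS` template parameters of the specialized requantization kernel through an 8-way if/else tree, so each instantiation's inner loop compiles with those branches eliminated. A self-contained sketch of this runtime-to-compile-time dispatch pattern (simplified and hypothetical, not FBGEMM's code):

```cpp
#include <cstdint>

// Hot loop compiled once per (A_SYMMETRIC, HAS_BIAS) combination; the
// per-element branches below are resolved at compile time, so dead code
// drops out of each specialization.
template <bool A_SYMMETRIC, bool HAS_BIAS>
void requantize_kernel(
    std::int32_t* acc, int n, std::int32_t correction, const std::int32_t* bias) {
  for (int i = 0; i < n; ++i) {
    if (!A_SYMMETRIC) {
      acc[i] -= correction; // skipped entirely in the A_SYMMETRIC variant
    }
    if (HAS_BIAS) {
      acc[i] += bias[i];
    }
  }
}

// Runtime flags select one of the four specializations, mirroring the
// (deeper) if/else tree in the diff below.
void dispatch(
    std::int32_t* acc, int n, std::int32_t correction,
    const std::int32_t* bias, bool a_symmetric) {
  if (a_symmetric) {
    if (bias) {
      requantize_kernel<true, true>(acc, n, correction, bias);
    } else {
      requantize_kernel<true, false>(acc, n, correction, bias);
    }
  } else {
    if (bias) {
      requantize_kernel<false, true>(acc, n, correction, bias);
    } else {
      requantize_kernel<false, false>(acc, n, correction, bias);
    }
  }
}
```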
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index f49da57..47c514d 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -1024,6 +1024,25 @@ class FBGEMM_API ReQuantizeOutput {
       int ld_out,
       int ld_in) const;
 
+  const float* getCMultiplier() const {
+    return C_multiplier_;
+  }
+  const std::int32_t getCZeroPoint() const {
+    return C_zero_point_;
+  }
+  const std::int32_t* getBZeroPoint() const {
+    return Bq_zero_point_;
+  }
+  const std::int32_t* getColOffsets() const {
+    return q_col_offsets_;
+  }
+  const std::int32_t* getBias() const {
+    return bias_;
+  }
+  const std::uint32_t getNCols() const {
+    return ncols_;
+  }
+
  private:
   nextOPType& nextop_;
   const float* C_multiplier_;
@@ -1176,6 +1195,25 @@ FBGEMM_API void fbgemmGroupwiseConv(
     const processOutputType& outProcess,
     int thread_id,
     int num_threads);
+
+template <
+    typename packed_W,
+    typename outType,
+    bool FUSE_RELU,
+    QuantizationGranularity Q_GRAN,
+    int SPATIAL_DIM = 2>
+FBGEMM_API void fbgemmGroupwiseConv(
+    const conv_param_t<SPATIAL_DIM>& conv_param,
+    const std::uint8_t* activations,
+    std::int32_t a_zero_point,
+    std::int32_t* rowOffsetBuf,
+    packed_W& packed_weights,
+    outType* out,
+    std::int32_t* outBuffer,
+    const ReQuantizeOutput<FUSE_RELU, Q_GRAN>& outProcess,
+    int thread_id,
+    int num_threads);
+
 /**
  * @return Size of row offset buffer in number of elements needed for
  * fbgemmGroupwiseConv
diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h
index ec7c6a5..40b830c 100644
--- a/include/fbgemm/QuantUtilsAvx2.h
+++ b/include/fbgemm/QuantUtilsAvx2.h
@@ -80,4 +80,18 @@ FBGEMM_API void requantizeOutputProcessingAvx2(
     int ld_in,
     const requantizationParams_t& r);
 
+template <
+    bool A_SYMMETRIC,
+    bool B_SYMMETRIC,
+    QuantizationGranularity Q_GRAN,
+    bool HAS_BIAS,
+    bool FUSE_RELU>
+FBGEMM_API void requantizeOutputProcessingGConv4Avx2(
+    std::uint8_t* out,
+    const std::int32_t* inp,
+    const block_type_t& block,
+    int ld_out,
+    int ld_in,
+    const requantizationParams_t& r);
+
 } // namespace fbgemm
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index 8e851b2..55389b3 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -1321,12 +1321,14 @@ GenConvKernel<int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
   return fn;
 }
 
+namespace {
+
 template <
     typename packed_W,
     typename outType,
     typename processOutputType,
     int SPATIAL_DIM>
-void fbgemmGroupwiseConv(
+void fbgemmGroupwiseConvBase_(
     const conv_param_t<SPATIAL_DIM>& conv_param,
     const std::uint8_t* activations,
     std::int32_t a_zero_point,
@@ -1345,11 +1347,11 @@ void fbgemmGroupwiseConv(
   int K_per_G = conv_param.OC / G;
   int C_per_G = conv_param.IC / G;
   int oh_ow = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
+  int ih_iw = conv_param.IN_DIM[0] * conv_param.IN_DIM[1];
 
   static_assert(SPATIAL_DIM == 2, "3D conv not supported yet");
 
-  int32_t* rowOffsetTrDest =
-      rowOffsetBuf + 8 * conv_param.IN_DIM[0] * conv_param.IN_DIM[1];
+  int32_t* rowOffsetTrDest = rowOffsetBuf + 8 * ih_iw;
   if (fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)) {
     assert(G % 8 == 0);
     // generate convolution kernel
@@ -1359,8 +1361,7 @@ void fbgemmGroupwiseConv(
     jit_rowoffset_kernel_fp fpRowoffset =
         getOrCreateRowOffsetKernel(conv_param, a_zero_point);
     for (int i = 0; i < MB; ++i) {
-      const uint8_t* actStartBatch = activations +
-          i * conv_param.IN_DIM[0] * conv_param.IN_DIM[1] * conv_param.IC;
+      const uint8_t* actStartBatch = activations + i * ih_iw * conv_param.IC;
       for (int gOuter = 0; gOuter < G; gOuter += 8) {
         // for C_per_G == 4 and K_per_G == 4, row offset is calculated for 8
         // groups at a time. The result is row offsets in the format IH*IW x G
@@ -1372,12 +1373,12 @@ void fbgemmGroupwiseConv(
             rowOffsetBuf);
         // Transpose to get row offsets in the format G x IH*IW
         internal::transpose_8x8(
-            conv_param.IN_DIM[0] * conv_param.IN_DIM[1],
+            ih_iw,
             8,
             reinterpret_cast<const float*>(rowOffsetBuf),
             8,
             reinterpret_cast<float*>(rowOffsetTrDest),
-            conv_param.IN_DIM[0] * conv_param.IN_DIM[1]);
+            ih_iw);
         int gLimit = gOuter + 8;
         for (int g = gOuter; g < gLimit; g += 2) {
           int32_t* currOutBuf =
@@ -1396,18 +1397,14 @@
           for (int j = 0; j < 2; ++j) {
             // calculateRowOffsets(
             //     conv_param, actStartGroup, rowOffsetBuf, a_zero_point, j);
-            int32_t* rowOffsetForCurG = rowOffsetTrDest +
-                ((g - gOuter) + j) * conv_param.IN_DIM[0] *
-                conv_param.IN_DIM[1];
+            int32_t* rowOffsetForCurG =
+                rowOffsetTrDest + ((g - gOuter) + j) * ih_iw;
             // compare_buffers(rowOffsetBuf, rowOffsetForCurG,
             //     conv_param.IN_DIM[0]*conv_param.IN_DIM[1], 1, 1, 100);
 
             // outProcess expects rowOffsetBuf to contain row offsets for the
             // current group
-            memcpy(
-                rowOffsetBuf,
-                rowOffsetForCurG,
-                conv_param.IN_DIM[0] * conv_param.IN_DIM[1] * sizeof(int32_t));
+            memcpy(rowOffsetBuf, rowOffsetForCurG, ih_iw * sizeof(int32_t));
 
             if (cpuinfo_has_x86_avx512f()) {
               // Currently use avx2 code
@@ -1460,6 +1457,221 @@ void fbgemmGroupwiseConv(
   }
 }
 
+}
+
+template <
+    typename packed_W,
+    typename outType,
+    typename processOutputType,
+    int SPATIAL_DIM>
+void fbgemmGroupwiseConv(
+    const conv_param_t<SPATIAL_DIM>& conv_param,
+    const std::uint8_t* activations,
+    std::int32_t a_zero_point,
+    std::int32_t* rowOffsetBuf,
+    packed_W& packed_weights,
+    outType* out,
+    int32_t* outBuffer,
+    const processOutputType& outProcess,
+    int thread_id,
+    int num_threads) {
+  return fbgemmGroupwiseConvBase_<
+      packed_W,
+      outType,
+      processOutputType,
+      SPATIAL_DIM>(
+      conv_param,
+      activations,
+      a_zero_point,
+      rowOffsetBuf,
+      packed_weights,
+      out,
+      outBuffer,
+      outProcess,
+      thread_id,
+      num_threads);
+}
+
+template <
+    typename packed_W,
+    typename outType,
+    bool FUSE_RELU,
+    QuantizationGranularity Q_GRAN,
+    int SPATIAL_DIM>
+void fbgemmGroupwiseConv(
+    const conv_param_t<SPATIAL_DIM>& conv_param,
+    const std::uint8_t* activations,
+    std::int32_t a_zero_point,
+    std::int32_t* rowOffsetBuf,
+    packed_W& packed_weights,
+    outType* out,
+    int32_t* outBuffer,
+    const ReQuantizeOutput<FUSE_RELU, Q_GRAN>& outProcess,
+    int thread_id,
+    int num_threads) {
+  typedef ReQuantizeOutput<FUSE_RELU, Q_GRAN> processOutputType;
+  if (!fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param) ||
+      (!cpuinfo_has_x86_avx512f() && !cpuinfo_has_x86_avx2())) {
+    return fbgemmGroupwiseConvBase_<
+        packed_W,
+        outType,
+        processOutputType,
+        SPATIAL_DIM>(
+        conv_param,
+        activations,
+        a_zero_point,
+        rowOffsetBuf,
+        packed_weights,
+        out,
+        outBuffer,
+        outProcess,
+        thread_id,
+        num_threads);
+  }
+
+  int MB = conv_param.MB;
+  int H = conv_param.OUT_DIM[0];
+  int W = conv_param.OUT_DIM[1];
+  int G = conv_param.G;
+  int K_per_G = conv_param.OC / G;
+  int C_per_G = conv_param.IC / G;
+  int oh_ow = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
+  int ih_iw = conv_param.IN_DIM[0] * conv_param.IN_DIM[1];
+
+  static_assert(SPATIAL_DIM == 2, "3D conv not supported yet");
+
+  int32_t* rowOffsetTrDest = rowOffsetBuf + 8 * ih_iw;
+  assert(G % 8 == 0);
+  // generate convolution kernel
+  jit_conv_kernel_fp fpConv =
+      getOrCreateConvKernel<>(conv_param, a_zero_point);
+  // generate row offset kernel
+  jit_rowoffset_kernel_fp fpRowoffset =
+      getOrCreateRowOffsetKernel(conv_param, a_zero_point);
+  for (int i = 0; i < MB; ++i) {
+    const uint8_t* actStartBatch = activations + i * ih_iw * conv_param.IC;
+    for (int gOuter = 0; gOuter < G; gOuter += 8) {
+      // for C_per_G == 4 and K_per_G == 4, row offset is calculated for 8
+      // groups at a time. The result is row offsets in the format IH*IW x G
+      fpRowoffset(
+          actStartBatch + gOuter * C_per_G,
+          a_zero_point,
+          H,
+          W,
+          rowOffsetBuf);
+      // Transpose to get row offsets in the format G x IH*IW
+      internal::transpose_8x8(
+          ih_iw,
+          8,
+          reinterpret_cast<const float*>(rowOffsetBuf),
+          8,
+          reinterpret_cast<float*>(rowOffsetTrDest),
+          ih_iw);
+      int gLimit = gOuter + 8;
+      for (int g = gOuter; g < gLimit; g += 2) {
+        int32_t* currOutBuf = outBuffer + (g - gOuter) * K_per_G;
+        const uint8_t* actStartGroup = actStartBatch + g * C_per_G;
+
+        fpConv(
+            actStartGroup,
+            packed_weights.getBuf() + g * K_per_G * C_per_G,
+            currOutBuf,
+            a_zero_point,
+            H,
+            W);
+      }
+
+      bool b_symmetric =
+          outProcess.getBZeroPoint()[0] == 0 || rowOffsetBuf == nullptr;
+
+      requantizationParams_t r = {a_zero_point,
+                                  outProcess.getBZeroPoint(),
+                                  outProcess.getCZeroPoint(),
+                                  outProcess.getCMultiplier(),
+                                  rowOffsetBuf,
+                                  outProcess.getColOffsets(),
+                                  outProcess.getBias(),
+                                  outProcess.getNCols(),
+                                  G};
+
+      const std::int32_t* inp = outBuffer;
+      block_type_t block{i * oh_ow, oh_ow, gOuter * K_per_G, 8 * K_per_G};
+      int ld_out = K_per_G * G;
+      int ld_in = K_per_G * G;
+
+      if (a_zero_point == 0) {
+        if (b_symmetric) {
+          if (outProcess.getBias() == nullptr) {
+            requantizeOutputProcessingGConv4Avx2<
+                true,
+                true,
+                Q_GRAN,
+                false,
+                FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+          } else {
+            requantizeOutputProcessingGConv4Avx2<
+                true,
+                true,
+                Q_GRAN,
+                true,
+                FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+          }
+        } else {
+          if (outProcess.getBias() == nullptr) {
+            requantizeOutputProcessingGConv4Avx2<
+                true,
+                false,
+                Q_GRAN,
+                false,
+                FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+          } else {
+            requantizeOutputProcessingGConv4Avx2<
+                true,
+                false,
+                Q_GRAN,
+                true,
+                FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+          }
+        }
+      } else {
+        if (b_symmetric) {
+          if (outProcess.getBias() == nullptr) {
+            requantizeOutputProcessingGConv4Avx2<
+                false,
+                true,
+                Q_GRAN,
+                false,
+                FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+          } else {
+            requantizeOutputProcessingGConv4Avx2<
+                false,
+                true,
+                Q_GRAN,
+                true,
+                FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+          }
+        } else {
+          if (outProcess.getBias() == nullptr) {
+            requantizeOutputProcessingGConv4Avx2<
+                false,
+                false,
+                Q_GRAN,
+                false,
+                FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+          } else {
+            requantizeOutputProcessingGConv4Avx2<
+                false,
+                false,
+                Q_GRAN,
+                true,
+                FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+          }
+        }
+      }
+    } // gOuter loop
+  } // i loop
+}
+
 jit_rowoffset_kernel_fp getOrCreateRowOffsetKernel(
     const conv_param_t<>& conv_param,
     int a_zero_point) {
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc
index 3f85d89..7c36f6d 100644
--- a/src/QuantUtilsAvx2.cc
+++ b/src/QuantUtilsAvx2.cc
@@ -657,24 +657,341 @@ void requantizeOutputProcessingAvx2(
   } // i loop
 }
 
-#define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU)                \
-  template void                                                     \
-  requantizeOutputProcessingAvx2<true, B_SYM, Q_GRAN, BIAS, RELU>(  \
-      uint8_t * out,                                                \
-      const int32_t* inp,                                           \
-      const block_type_t& block,                                    \
-      int ld_out,                                                   \
-      int ld_in,                                                    \
-      const requantizationParams_t& r);                             \
-  template void                                                     \
-  requantizeOutputProcessingAvx2<false, B_SYM, Q_GRAN, BIAS, RELU>( \
-      uint8_t * out,                                                \
-      const int32_t* inp,                                           \
-      const block_type_t& block,                                    \
-      int ld_out,                                                   \
-      int ld_in,                                                    \
+template <
+    bool A_SYMMETRIC,
+    bool B_SYMMETRIC,
+    QuantizationGranularity Q_GRAN,
+    bool HAS_BIAS,
+    bool FUSE_RELU>
+void requantizeOutputProcessingGConv4Avx2(
+    uint8_t* out,
+    const int32_t* inp,
+    const block_type_t& block,
+    int ld_out,
+    int ld_in,
+    const requantizationParams_t& r) {
+  // Adaptation of the implementation at QNNPACK/src/requantization/fp32-sse2.c
+  // using AVX2 instructions
+  int quant_param_idx = 0;
+  if (Q_GRAN == QuantizationGranularity::GROUP) {
+    int ncol_per_group = r.ncols / r.groups;
+    int g = block.col_start / ncol_per_group;
+    quant_param_idx = g;
+  }
+  __m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]);
+
+  __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
+  __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));
+
+  assert(
+      (A_SYMMETRIC == (r.A_zero_point == 0)) &&
+      "A_SYMMETRIC == true if and only if A_zero_point == 0");
+  assert(
+      (B_SYMMETRIC ==
+       ((Q_GRAN == QuantizationGranularity::TENSOR &&
+         r.B_zero_point[0] == 0) ||
+        r.row_offsets == nullptr)) &&
+      "B_SYMMETRIC == true if and only if B_zero_point == 0 "
+      "or r.row_offsets == nullptr");
+  assert(
+      (HAS_BIAS == (r.bias != nullptr)) &&
+      "HAS_BIAS == true if and only if bias != nullptr");
+
+  __m256i A_zero_point_v = _mm256_set1_epi32(r.A_zero_point);
+  __m256i C_zero_point_epi16_v = _mm256_set1_epi16(r.C_zero_point);
+  __m256i C_zero_point_epi8_v = _mm256_set1_epi8(r.C_zero_point);
+
+  __m256i permute_mask_v =
+      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
+
+  constexpr int VLEN = 8;
+  for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+    int j = block.col_start;
+    for (; j < block.col_start + (block.col_size / (VLEN * 4) * (VLEN * 4));
+         j += (VLEN * 4)) {
+      __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+          inp + (i - block.row_start) * ld_in + (j - block.col_start)));
+      __m256i y_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+          inp + (i - block.row_start) * ld_in + (j - block.col_start) +
+          1 * VLEN));
+      __m256i z_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+          inp + (i - block.row_start) * ld_in + (j - block.col_start) +
+          2 * VLEN));
+      __m256i w_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+          inp + (i - block.row_start) * ld_in + (j - block.col_start) +
+          3 * VLEN));
+
+      if (!A_SYMMETRIC) {
+        __m256i col_off_v = _mm256_mullo_epi32(
+            A_zero_point_v,
+            _mm256_loadu_si256(
+                reinterpret_cast<const __m256i*>(r.col_offsets + j)));
+        x_v = _mm256_sub_epi32(x_v, col_off_v);
+        col_off_v = _mm256_mullo_epi32(
+            A_zero_point_v,
+            _mm256_loadu_si256(
+                reinterpret_cast<const __m256i*>(r.col_offsets + j + VLEN)));
+        y_v = _mm256_sub_epi32(y_v, col_off_v);
+        col_off_v = _mm256_mullo_epi32(
+            A_zero_point_v,
+            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+                r.col_offsets + j + 2 * VLEN)));
+        z_v = _mm256_sub_epi32(z_v, col_off_v);
+        col_off_v = _mm256_mullo_epi32(
+            A_zero_point_v,
+            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
+                r.col_offsets + j + 3 * VLEN)));
+        w_v = _mm256_sub_epi32(w_v, col_off_v);
+      }
+
+      if (!B_SYMMETRIC) {
+        // Load row_offsets for 2 groups and broadcast each 4 times because
+        // we have 4 channels per group.
+
+        // groups 0 and 1
+        __m256i row_offset_v = _mm256_insertf128_si256(
+            _mm256_castsi128_si256(
+                _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 0])),
+            _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 1]),
+            1);
+        __m256i B_zero_point_v = _mm256_set1_epi32(r.B_zero_point[0]);
+        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+          B_zero_point_v = _mm256_loadu_si256(
+              reinterpret_cast<const __m256i*>(r.B_zero_point + j));
+        } else if (Q_GRAN == QuantizationGranularity::GROUP) {
+          B_zero_point_v = _mm256_insertf128_si256(
+              _mm256_castsi128_si256(
+                  _mm_set1_epi32(r.B_zero_point[quant_param_idx])),
+              _mm_set1_epi32(r.B_zero_point[quant_param_idx + 1]),
+              1);
+        }
+        row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
+        x_v = _mm256_sub_epi32(x_v, row_offset_v);
+
+        // groups 2 and 3
+        row_offset_v = _mm256_insertf128_si256(
+            _mm256_castsi128_si256(
+                _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 2])),
+            _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 3]),
+            1);
+        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+          B_zero_point_v = _mm256_loadu_si256(
+              reinterpret_cast<const __m256i*>(r.B_zero_point + j + VLEN));
+        } else if (Q_GRAN == QuantizationGranularity::GROUP) {
+          B_zero_point_v = _mm256_insertf128_si256(
+              _mm256_castsi128_si256(
+                  _mm_set1_epi32(r.B_zero_point[quant_param_idx + 2])),
+              _mm_set1_epi32(r.B_zero_point[quant_param_idx + 3]),
+              1);
+        }
+        row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
+        y_v = _mm256_sub_epi32(y_v, row_offset_v);
+
+        // groups 4 and 5
+        row_offset_v = _mm256_insertf128_si256(
+            _mm256_castsi128_si256(
+                _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 4])),
+            _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 5]),
+            1);
+        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+          B_zero_point_v = _mm256_loadu_si256(
+              reinterpret_cast<const __m256i*>(r.B_zero_point + j + 2 * VLEN));
+        } else if (Q_GRAN == QuantizationGranularity::GROUP) {
+          B_zero_point_v = _mm256_insertf128_si256(
+              _mm256_castsi128_si256(
+                  _mm_set1_epi32(r.B_zero_point[quant_param_idx + 4])),
+              _mm_set1_epi32(r.B_zero_point[quant_param_idx + 5]),
+              1);
+        }
+        row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
+        z_v = _mm256_sub_epi32(z_v, row_offset_v);
+
+        // groups 6 and 7
+        row_offset_v = _mm256_insertf128_si256(
+            _mm256_castsi128_si256(
+                _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 6])),
+            _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 7]),
+            1);
+        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+          B_zero_point_v = _mm256_loadu_si256(
+              reinterpret_cast<const __m256i*>(r.B_zero_point + j + 3 * VLEN));
+        } else if (Q_GRAN == QuantizationGranularity::GROUP) {
+          B_zero_point_v = _mm256_insertf128_si256(
+              _mm256_castsi128_si256(
+                  _mm_set1_epi32(r.B_zero_point[quant_param_idx + 6])),
+              _mm_set1_epi32(r.B_zero_point[quant_param_idx + 7]),
+              1);
+        }
+        row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
+        w_v = _mm256_sub_epi32(w_v, row_offset_v);
+      }
+      if (HAS_BIAS) {
+        x_v = _mm256_add_epi32(
+            x_v,
+            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
+        y_v = _mm256_add_epi32(
+            y_v,
+            _mm256_loadu_si256(
+                reinterpret_cast<const __m256i*>(r.bias + j + VLEN)));
+        z_v = _mm256_add_epi32(
+            z_v,
+            _mm256_loadu_si256(
+                reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
+        w_v = _mm256_add_epi32(
+            w_v,
+            _mm256_loadu_si256(
+                reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+      }
+
+      /*
+       * Convert int32_t input to FP32 and multiply by FP32 scale.
+       * Both operations involve statistically unbiased roundings (with
+       * default MXCSR rounding mode):
+       * - Large int32_t values can't be exactly represented as FP32.
+       *   CVTDQ2PS instruction on x86 would round it according to nearest
+       *   FP32 value with ties to even (assuming default MXCSR rounding
+       *   mode).
+       * - Product of two FP32 values is generally not exactly
+       *   representable as an FP32 value, and will be rounded to nearest
+       *   FP32 value with ties to even with default MXCSR rounding mode.
+       */
+      __m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v;
+      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+        x_scaled_v = _mm256_mul_ps(
+            _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j));
+        y_scaled_v = _mm256_mul_ps(
+            _mm256_cvtepi32_ps(y_v),
+            _mm256_loadu_ps(r.C_multiplier + j + VLEN));
+        z_scaled_v = _mm256_mul_ps(
+            _mm256_cvtepi32_ps(z_v),
+            _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
+        w_scaled_v = _mm256_mul_ps(
+            _mm256_cvtepi32_ps(w_v),
+            _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
+      } else if (Q_GRAN == QuantizationGranularity::GROUP) {
+        multiplier_v = _mm256_insertf128_ps(
+            _mm256_castps128_ps256(
+                _mm_set1_ps(r.C_multiplier[quant_param_idx])),
+            _mm_set1_ps(r.C_multiplier[quant_param_idx + 1]),
+            1);
+        x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+
+        multiplier_v = _mm256_insertf128_ps(
+            _mm256_castps128_ps256(
+                _mm_set1_ps(r.C_multiplier[quant_param_idx + 2])),
+            _mm_set1_ps(r.C_multiplier[quant_param_idx + 3]),
+            1);
+        y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
+
+        multiplier_v = _mm256_insertf128_ps(
+            _mm256_castps128_ps256(
+                _mm_set1_ps(r.C_multiplier[quant_param_idx + 4])),
+            _mm_set1_ps(r.C_multiplier[quant_param_idx + 5]),
+            1);
+        z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
+
+        multiplier_v = _mm256_insertf128_ps(
+            _mm256_castps128_ps256(
+                _mm_set1_ps(r.C_multiplier[quant_param_idx + 6])),
+            _mm_set1_ps(r.C_multiplier[quant_param_idx + 7]),
+            1);
+        w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+      } else {
+        x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+        y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
+        z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
+        w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+      }
+
+      /*
+       * Convert scaled FP32 result to int32_t using CVTPS2DQ instruction.
+       * CVTPS2DQ instruction rounds result according to nearest FP32 value
+       * with ties to even (assuming default MXCSR rounding mode). However,
+       * when conversion overflows, it produces INT32_MIN as a result. For
+       * large positive inputs the result of conversion can become negative,
+       * which affects the final requantization result. Note that on x86
+       * SSE2 we have e.g. int32_t(float(INT32_MAX)) == INT32_MIN! This
+       * happens because float(INT32_MAX) rounds to 2**31, which overflows
+       * int32_t when it is converted back to integer.
+       *
+       * Thankfully, we can prove that overflow never happens in this
+       * requantization scheme. The largest positive input is INT32_MAX
+       * (2**31 - 1), which turns into 2**31 when converted to float. The
+       * largest scale value is 0x1.FFFFFEp-1. When multiplied together, the
+       * result is 2147483520 (compare to INT32_MAX = 2147483647), which
+       * fits into int32_t without overflow.
+       */
+      __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
+      __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
+      __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
+      __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);
+
+      /*
+       * Standard final sequence on x86 AVX2:
+       * - Pack to int16_t and saturate
+       * - Add zero point
+       * - Pack to uint8_t and saturate
+       * - Clamp between qmin and qmax
+       */
+      __m256i xy_packed_v = _mm256_adds_epi16(
+          _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v);
+      __m256i zw_packed_v = _mm256_adds_epi16(
+          _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
+      __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
+      __m256i xyzw_clamped_v = _mm256_max_epu8(
+          FUSE_RELU ? C_zero_point_epi8_v : min_v,
+          _mm256_min_epu8(xyzw_packed_v, max_v));
+
+      /*
+       * xyzw_clamped_v has results in the following layout so we need to
+       * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7
+       */
+      xyzw_clamped_v =
+          _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
+
+      /*
+       * 4x CVTDQ2PS
+       * 4x MULPS
+       * 4x CVTPS2DQ
+       * 2x PACKSSDW
+       * 1x PACKUSWB
+       * 2x PADDW
+       * 1x PMAXUB
+       * 1x PMINUB
+       * 1x PERMD
+       * ---------------------
+       * 20 instructions total
+       */
+      _mm256_storeu_si256(
+          reinterpret_cast<__m256i*>(out + i * ld_out + j), xyzw_clamped_v);
+    } // j loop vectorized and unrolled 4x
+
+    int remainder = block.col_start + block.col_size - j;
+    assert(remainder == 0);
+  } // i loop
+}
+
+#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU)          \
+  template void                                                           \
+  requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>(       \
+      uint8_t * out,                                                      \
+      const int32_t* inp,                                                 \
+      const block_type_t& block,                                          \
+      int ld_out,                                                         \
+      int ld_in,                                                          \
+      const requantizationParams_t& r);                                   \
+  template void                                                           \
+  requantizeOutputProcessingGConv4Avx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
+      uint8_t * out,                                                      \
+      const int32_t* inp,                                                 \
+      const block_type_t& block,                                          \
+      int ld_out,                                                         \
+      int ld_in,                                                          \
       const requantizationParams_t& r);
 
+#define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU)      \
+  INSTANTIATE_REQUANTIZE(true, B_SYM, Q_GRAN, BIAS, RELU) \
+  INSTANTIATE_REQUANTIZE(false, B_SYM, Q_GRAN, BIAS, RELU)
+
 #define INSTANTIATE_B_SYM(Q_GRAN, BIAS, RELU) \
   INSTANTIATE_A_SYM(true, Q_GRAN, BIAS, RELU) \
   INSTANTIATE_A_SYM(false, Q_GRAN, BIAS, RELU)
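One detail worth calling out in requantizeOutputProcessingGConv4Avx2 above: _mm256_packs_epi32 and _mm256_packus_epi16 interleave their operands per 128-bit lane, which is why the packed bytes land in the order x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7 and need the final cross-lane permute. A standalone demonstration of that pack-and-permute sequence (illustrative values, not FBGEMM code; compile with -mavx2):

```cpp
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
  // x = 0..7, y = 8..15, z = 16..23, w = 24..31 as int32 elements.
  alignas(32) int32_t src[32];
  for (int i = 0; i < 32; ++i) src[i] = i;
  __m256i x = _mm256_load_si256(reinterpret_cast<const __m256i*>(src + 0));
  __m256i y = _mm256_load_si256(reinterpret_cast<const __m256i*>(src + 8));
  __m256i z = _mm256_load_si256(reinterpret_cast<const __m256i*>(src + 16));
  __m256i w = _mm256_load_si256(reinterpret_cast<const __m256i*>(src + 24));

  // Lane-wise packs: bytes come out as x0-3 y0-3 z0-3 w0-3 | x4-7 y4-7 z4-7 w4-7.
  __m256i xy = _mm256_packs_epi32(x, y); // int32 -> int16, saturating
  __m256i zw = _mm256_packs_epi32(z, w);
  __m256i xyzw = _mm256_packus_epi16(xy, zw); // int16 -> uint8, saturating

  // Cross-lane fixup with the same permute mask used in the kernel:
  // treat the vector as eight 32-bit groups of 4 bytes and reorder them.
  __m256i permute_mask =
      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
  xyzw = _mm256_permutevar8x32_epi32(xyzw, permute_mask);

  alignas(32) uint8_t dst[32];
  _mm256_store_si256(reinterpret_cast<__m256i*>(dst), xyzw);
  for (int i = 0; i < 32; ++i) std::printf("%d ", dst[i]); // prints 0 1 ... 31
  std::printf("\n");
  return 0;
}
```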