diff options
author | Jongsoo Park <jongsoo@fb.com> | 2019-01-31 22:02:39 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-31 22:34:30 +0300 |
commit | a98f5d8ffe070b0124dcd34963e6ae55c7864407 (patch) | |
tree | ba7678c579f2556c0b356ee9bb514434db5d5b04 /src | |
parent | 1c3685adb3aff4241d83da0c62f94f4f4bd37511 (diff) |
optimize requantization remainder (#64)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/64
Vectorize the remainder columns in requantizeOutputProcessingAvx2 using AVX2 masked loads instead of a scalar per-element loop
Reviewed By: dskhudia
Differential Revision: D13893809
fbshipit-source-id: 8e33c85d65b2dcf0cdb8e92372c44dcc9bcf6824
Diffstat (limited to 'src')
-rw-r--r-- | src/GroupwiseConvAcc32Avx2.cc | 10 | ||||
-rw-r--r-- | src/QuantUtilsAvx2.cc | 87 |
2 files changed, 77 insertions, 20 deletions
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index 00a9571..8e851b2 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -1374,9 +1374,9 @@ void fbgemmGroupwiseConv( internal::transpose_8x8( conv_param.IN_DIM[0] * conv_param.IN_DIM[1], 8, - (const float*)rowOffsetBuf, + reinterpret_cast<const float*>(rowOffsetBuf), 8, - (float*)rowOffsetTrDest, + reinterpret_cast<float*>(rowOffsetTrDest), conv_param.IN_DIM[0] * conv_param.IN_DIM[1]); int gLimit = gOuter + 8; for (int g = gOuter; g < gLimit; g += 2) { @@ -1429,9 +1429,9 @@ void fbgemmGroupwiseConv( assert(0 && "unsupported architecure"); } } // j loop - } - } - } + } // g loop + } // gOuter loop + } // i loop } else { // for the not supported cases, just execute the naive C implementation conv_ref( diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc index 3905d65..3f85d89 100644 --- a/src/QuantUtilsAvx2.cc +++ b/src/QuantUtilsAvx2.cc @@ -571,31 +571,88 @@ void requantizeOutputProcessingAvx2( _mm256_castsi256_si128(x_clamped_v)); } // j loop vectorized - // TODO: vectorize remainder using masking - for (; j < block.col_start + block.col_size; ++j) { - int32_t raw = inp[(i - block.row_start) * ld_in + (j - block.col_start)]; + int remainder = block.col_start + block.col_size - j; + if (remainder > 0) { + alignas(64) const int masks[8][8] = { + // NOTE: clang-format wants to use a different formatting but the + // current formatting should be easier to read. 
+ { 0, 0, 0, 0, 0, 0, 0, 0, }, + { -1, 0, 0, 0, 0, 0, 0, 0, }, + { -1, -1, 0, 0, 0, 0, 0, 0, }, + { -1, -1, -1, 0, 0, 0, 0, 0, }, + { -1, -1, -1, -1, 0, 0, 0, 0, }, + { -1, -1, -1, -1, -1, 0, 0, 0, }, + { -1, -1, -1, -1, -1, -1, 0, 0, }, + { -1, -1, -1, -1, -1, -1, -1, 0, }, + }; + __m256i mask_v = _mm256_load_si256( + reinterpret_cast<const __m256i*>(masks[remainder])); + + __m256i x_v = _mm256_maskload_epi32( + inp + (i - block.row_start) * ld_in + (j - block.col_start), + mask_v); + if (!A_SYMMETRIC) { - raw -= r.A_zero_point * r.col_offsets[j]; + __m256i col_off_v = _mm256_mullo_epi32( + A_zero_point_v, _mm256_maskload_epi32(r.col_offsets + j, mask_v)); + x_v = _mm256_sub_epi32(x_v, col_off_v); } + if (!B_SYMMETRIC) { if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { - row_offset = r.row_offsets[i - block.row_start] * r.B_zero_point[j]; + row_offset_v = _mm256_mullo_epi32( + _mm256_set1_epi32(r.row_offsets[i - block.row_start]), + _mm256_maskload_epi32(r.B_zero_point + j, mask_v)); } - raw -= row_offset; + x_v = _mm256_sub_epi32(x_v, row_offset_v); } if (HAS_BIAS) { - raw += r.bias[j]; + x_v = _mm256_add_epi32(x_v, _mm256_maskload_epi32(r.bias + j, mask_v)); } - float ab = raw * - ((Q_GRAN == QuantizationGranularity::OUT_CHANNEL) - ? r.C_multiplier[j] - : r.C_multiplier[quant_param_idx]); - long rounded = std::lrintf(ab) + r.C_zero_point; + __m256 x_scaled_v; + if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { + x_scaled_v = _mm256_mul_ps( + _mm256_cvtepi32_ps(x_v), + _mm256_maskload_ps(r.C_multiplier + j, mask_v)); + } else { + x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + } + __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); + + __m256i x_packed_v = _mm256_adds_epi16( + _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()), + C_zero_point_epi16_v); + x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256()); + __m256i x_clamped_v = _mm256_max_epu8( + FUSE_RELU ? 
C_zero_point_epi8_v : min_v, + _mm256_min_epu8(x_packed_v, max_v)); - out[i * ld_out + j] = std::max( - FUSE_RELU ? static_cast<long>(r.C_zero_point) : 0l, - std::min(255l, rounded)); + /* + * x_clamped_v has results in the following layout so we need to + * permute: x0-3 garbage0-11 x4-7 garbage12-23 + */ + x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v); + + /* + * 1x CVTDQ2PS + * 1x MULPS + * 1x CVTPS2DQ + * 1x PACKSSDW + * 1x PACKUSWB + * 1x PADDW + * 1x PMAXUB + * 1x PMINUB + * 1x PERMD + * --------------------- + * 9 instructions total + */ + alignas(64) uint8_t x_clamped_buffer[32]; + _mm256_store_si256( + reinterpret_cast<__m256i*>(x_clamped_buffer), x_clamped_v); + for (int k = 0; k < remainder; ++k) { + out[i * ld_out + j + k] = x_clamped_buffer[k]; + } } // j loop remainder } // i loop } |