
github.com/marian-nmt/FBGEMM.git
author    Jongsoo Park <jongsoo@fb.com>    2019-01-31 22:02:39 +0300
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>    2019-01-31 22:34:30 +0300
commit    a98f5d8ffe070b0124dcd34963e6ae55c7864407 (patch)
tree      ba7678c579f2556c0b356ee9bb514434db5d5b04 /src
parent    1c3685adb3aff4241d83da0c62f94f4f4bd37511 (diff)
optimize requantization remainder (#64)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/64

Use mask instead of scalar code

Reviewed By: dskhudia
Differential Revision: D13893809
fbshipit-source-id: 8e33c85d65b2dcf0cdb8e92372c44dcc9bcf6824
Diffstat (limited to 'src')
-rw-r--r--  src/GroupwiseConvAcc32Avx2.cc | 10
-rw-r--r--  src/QuantUtilsAvx2.cc         | 87
2 files changed, 77 insertions(+), 20 deletions(-)
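
The gist of the change: the requantization epilogue used to finish the last (col_size % 8) output columns with a scalar loop; the patch instead indexes a small lookup table of lane masks by the remainder count and keeps the tail on the same AVX2 path via masked loads. Below is a minimal standalone sketch of that masked-tail pattern (illustrative code, not taken from the patch; the function name and the plain copy operation are stand-ins for the real requantization math):

#include <immintrin.h>
#include <cstdint>

// Illustrative sketch: process n int32 values, vectorizing the tail with a
// masked load/store instead of a scalar loop. Lane k of masks[r] is -1
// (all bits set, lane enabled) for k < r and 0 (lane disabled) otherwise.
static void copy_with_masked_tail(const int32_t* src, int32_t* dst, int n) {
  int j = 0;
  for (; j + 8 <= n; j += 8) {
    __m256i v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + j));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst + j), v);
  }
  int remainder = n - j; // 0..7 leftover elements
  if (remainder > 0) {
    alignas(64) static const int masks[8][8] = {
        {  0,  0,  0,  0,  0,  0,  0,  0 },
        { -1,  0,  0,  0,  0,  0,  0,  0 },
        { -1, -1,  0,  0,  0,  0,  0,  0 },
        { -1, -1, -1,  0,  0,  0,  0,  0 },
        { -1, -1, -1, -1,  0,  0,  0,  0 },
        { -1, -1, -1, -1, -1,  0,  0,  0 },
        { -1, -1, -1, -1, -1, -1,  0,  0 },
        { -1, -1, -1, -1, -1, -1, -1,  0 },
    };
    __m256i mask_v =
        _mm256_load_si256(reinterpret_cast<const __m256i*>(masks[remainder]));
    // Disabled lanes read as 0 and are not written back.
    __m256i v = _mm256_maskload_epi32(src + j, mask_v);
    _mm256_maskstore_epi32(dst + j, mask_v, v);
  }
}

Because VMASKMOV neither reads nor writes disabled lanes, the masked tail never touches memory past the end of the row, so the vector path is safe even when fewer than 8 elements remain.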
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index 00a9571..8e851b2 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -1374,9 +1374,9 @@ void fbgemmGroupwiseConv(
internal::transpose_8x8(
conv_param.IN_DIM[0] * conv_param.IN_DIM[1],
8,
- (const float*)rowOffsetBuf,
+ reinterpret_cast<const float*>(rowOffsetBuf),
8,
- (float*)rowOffsetTrDest,
+ reinterpret_cast<float*>(rowOffsetTrDest),
conv_param.IN_DIM[0] * conv_param.IN_DIM[1]);
int gLimit = gOuter + 8;
for (int g = gOuter; g < gLimit; g += 2) {
@@ -1429,9 +1429,9 @@ void fbgemmGroupwiseConv(
assert(0 && "unsupported architecture");
}
} // j loop
- }
- }
- }
+ } // g loop
+ } // gOuter loop
+ } // i loop
} else {
// for the unsupported cases, just execute the naive C implementation
conv_ref(
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc
index 3905d65..3f85d89 100644
--- a/src/QuantUtilsAvx2.cc
+++ b/src/QuantUtilsAvx2.cc
@@ -571,31 +571,88 @@ void requantizeOutputProcessingAvx2(
_mm256_castsi256_si128(x_clamped_v));
} // j loop vectorized
- // TODO: vectorize remainder using masking
- for (; j < block.col_start + block.col_size; ++j) {
- int32_t raw = inp[(i - block.row_start) * ld_in + (j - block.col_start)];
+ int remainder = block.col_start + block.col_size - j;
+ if (remainder > 0) {
+ alignas(64) const int masks[8][8] = {
+ // NOTE: clang-format wants to use a different formatting but the
+ // current formatting should be easier to read.
+ { 0, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, -1, 0, },
+ };
+ __m256i mask_v = _mm256_load_si256(
+ reinterpret_cast<const __m256i*>(masks[remainder]));
+
+ __m256i x_v = _mm256_maskload_epi32(
+ inp + (i - block.row_start) * ld_in + (j - block.col_start),
+ mask_v);
+
if (!A_SYMMETRIC) {
- raw -= r.A_zero_point * r.col_offsets[j];
+ __m256i col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v, _mm256_maskload_epi32(r.col_offsets + j, mask_v));
+ x_v = _mm256_sub_epi32(x_v, col_off_v);
}
+
if (!B_SYMMETRIC) {
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- row_offset = r.row_offsets[i - block.row_start] * r.B_zero_point[j];
+ row_offset_v = _mm256_mullo_epi32(
+ _mm256_set1_epi32(r.row_offsets[i - block.row_start]),
+ _mm256_maskload_epi32(r.B_zero_point + j, mask_v));
}
- raw -= row_offset;
+ x_v = _mm256_sub_epi32(x_v, row_offset_v);
}
if (HAS_BIAS) {
- raw += r.bias[j];
+ x_v = _mm256_add_epi32(x_v, _mm256_maskload_epi32(r.bias + j, mask_v));
}
- float ab = raw *
- ((Q_GRAN == QuantizationGranularity::OUT_CHANNEL)
- ? r.C_multiplier[j]
- : r.C_multiplier[quant_param_idx]);
- long rounded = std::lrintf(ab) + r.C_zero_point;
+ __m256 x_scaled_v;
+ if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ x_scaled_v = _mm256_mul_ps(
+ _mm256_cvtepi32_ps(x_v),
+ _mm256_maskload_ps(r.C_multiplier + j, mask_v));
+ } else {
+ x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ }
+ __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
+
+ __m256i x_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
+ C_zero_point_epi16_v);
+ x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
+ __m256i x_clamped_v = _mm256_max_epu8(
+ FUSE_RELU ? C_zero_point_epi8_v : min_v,
+ _mm256_min_epu8(x_packed_v, max_v));
- out[i * ld_out + j] = std::max(
- FUSE_RELU ? static_cast<long>(r.C_zero_point) : 0l,
- std::min(255l, rounded));
+ /*
+ * x_clamped_v has results in the following layout so we need to
+ * permute: x0-3 garbage0-11 x4-7 garbage12-23
+ */
+ x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
+
+ /*
+ * 1x CVTDQ2PS
+ * 1x MULPS
+ * 1x CVTPS2DQ
+ * 1x PACKSSDW
+ * 1x PACKUSWB
+ * 1x PADDW
+ * 1x PMAXUB
+ * 1x PMINUB
+ * 1x PERMD
+ * ---------------------
+ * 9 instructions total
+ */
+ alignas(64) uint8_t x_clamped_buffer[32];
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(x_clamped_buffer), x_clamped_v);
+ for (int k = 0; k < remainder; ++k) {
+ out[i * ld_out + j + k] = x_clamped_buffer[k];
+ }
} // j loop remainder
} // i loop
}
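
As the layout comment in the new remainder path notes, VPACKSSDW and VPACKUSWB operate independently on each 128-bit lane, which is why the eight result bytes come out split as x0-3 / garbage / x4-7 / garbage and need a VPERMD to become contiguous. A small self-contained demonstration of that lane behavior (illustrative only; the (0,4,1,5,2,6,3,7) index pattern matches what the file's permute_mask_v appears to hold):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Pack eight int32 results against zeros and observe where the payload
// bytes land, then fix the layout with a cross-lane permute.
int main() {
  __m256i x = _mm256_setr_epi32(10, 20, 30, 40, 50, 60, 70, 80);
  __m256i w16 = _mm256_packs_epi32(x, _mm256_setzero_si256());   // PACKSSDW
  __m256i b8 = _mm256_packus_epi16(w16, _mm256_setzero_si256()); // PACKUSWB
  // Payload now sits in dwords 0 and 4; gather them into dwords 0-1.
  __m256i permute_mask_v = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
  __m256i fixed = _mm256_permutevar8x32_epi32(b8, permute_mask_v); // PERMD

  alignas(32) uint8_t buf[32];
  _mm256_store_si256(reinterpret_cast<__m256i*>(buf), fixed);
  for (int k = 0; k < 8; ++k) {
    printf("%d ", buf[k]); // prints: 10 20 30 40 50 60 70 80
  }
  printf("\n");
  return 0;
}

After the permute, the first 8 bytes of the register hold the finished uint8 outputs, so the remainder path can spill the register to an aligned buffer and copy out exactly `remainder` bytes.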