Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/PackMatrix.cc')
-rw-r--r--src/PackMatrix.cc86
1 files changed, 86 insertions, 0 deletions
diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc
new file mode 100644
index 0000000..85000ac
--- /dev/null
+++ b/src/PackMatrix.cc
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <cpuinfo.h>
+#include <iomanip>
+#include <stdexcept>
+#include <type_traits>
+#include "fbgemm/ConvUtils.h"
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm2 {
+
+template <typename PT, typename inpType, typename accType>
+PackMatrix<PT, inpType, accType>::PackMatrix(
+ int32_t rows,
+ int32_t cols,
+ inpType* buf,
+ int32_t zero_pt)
+ : buf_(buf), nrows_(rows), ncols_(cols), zero_pt_(zero_pt) {
+ bufAllocatedHere_ = false;
+ if (!cpuinfo_initialize()) {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+template <typename PT, typename inpType, typename accType>
+int PackMatrix<PT, inpType, accType>::packedBufferSize(int rows, int cols) {
+ if (cpuinfo_has_x86_avx512f()) {
+ if (isA) {
+ return PackingTraits<inpType, accType, inst_set_t::avx512>::MCB *
+ PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
+ } else {
+ int rowBlock = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
+ int colBlock = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB;
+ return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
+ (((cols + colBlock - 1) / colBlock) * colBlock);
+ }
+ } else if (cpuinfo_has_x86_avx2()) {
+ if (isA) {
+ return PackingTraits<inpType, accType, inst_set_t::avx2>::MCB *
+ PackingTraits<inpType, accType, inst_set_t::avx2>::KCB;
+ } else {
+ int rowBlock = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB;
+ int colBlock = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB;
+ return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
+ (((cols + colBlock - 1) / colBlock) * colBlock);
+ }
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecure");
+ }
+ return -1;
+}
+
+// int32 accumulation
+template class PackMatrix<PackAMatrix<uint8_t, int32_t>, uint8_t, int32_t>;
+
+template class PackMatrix<
+ PackAWithRowOffset<uint8_t, int32_t>,
+ uint8_t,
+ int32_t>;
+
+template class PackMatrix<PackAWithIm2Col<uint8_t, int32_t>, uint8_t, int32_t>;
+
+template class PackMatrix<
+ PackAWithQuantRowOffset<uint8_t, int32_t>,
+ uint8_t,
+ int32_t>;
+
+template class PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>;
+
+// int16 accumulation
+template class PackMatrix<PackAWithIm2Col<uint8_t, int16_t>, uint8_t, int16_t>;
+
+template class PackMatrix<
+ PackAWithRowOffset<uint8_t, int16_t>,
+ uint8_t,
+ int16_t>;
+
+template class PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>;
+
+template class PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>;
+} // namespace fbgemm2