Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/PackAWithIm2Col.cc')
-rw-r--r--src/PackAWithIm2Col.cc146
1 files changed, 146 insertions, 0 deletions
diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc
new file mode 100644
index 0000000..7012289
--- /dev/null
+++ b/src/PackAWithIm2Col.cc
@@ -0,0 +1,146 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <cpuinfo.h>
+#include <cassert>
+#include <iomanip>
+#include <iostream>
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm2 {
+
+template <typename T, typename accT>
+PackAWithIm2Col<T, accT>::PackAWithIm2Col(
+ const conv_param_t& conv_p,
+ const T* sdata,
+ inpType* pmat,
+ int32_t zero_pt,
+ int32_t* row_offset)
+ : PackMatrix<PackAWithIm2Col<T, accT>, T, accT>(
+ conv_p.MB * conv_p.OH * conv_p.OW,
+ conv_p.KH * conv_p.KW * conv_p.IC,
+ pmat,
+ zero_pt),
+ conv_p_(conv_p),
+ sdata_(sdata) {
+ assert(conv_p.G == 1 && "Groups != 1 not supported yet");
+
+ if (cpuinfo_has_x86_avx512f()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
+ } else if (cpuinfo_has_x86_avx2()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecure");
+ }
+ if (pmat) {
+ BaseType::buf_ = pmat;
+ } else {
+ BaseType::bufAllocatedHere_ = true;
+ BaseType::buf_ = static_cast<T*>(
+ aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)));
+ }
+ if (row_offset) {
+ row_offset_ = row_offset;
+ } else {
+ rowOffsetAllocatedHere = true;
+ row_offset_ = static_cast<int32_t*>(
+ aligned_alloc(64, BaseType::brow_ * sizeof(int32_t)));
+ }
+}
+
+template <typename T, typename accT>
+void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) {
+ block_type_t block_p = {block.row_start,
+ block.row_size,
+ block.col_start,
+ (block.col_size + row_interleave_B_ - 1) /
+ row_interleave_B_ * row_interleave_B_};
+
+ BaseType::packedBlock(block_p);
+ T* out = BaseType::getBuf();
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ int n = i / (conv_p_.OH * conv_p_.OW);
+ int hw = i % (conv_p_.OH * conv_p_.OW);
+ int w = hw % conv_p_.OW;
+ int h = hw / conv_p_.OW;
+ for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
+ int c = j % conv_p_.IC;
+ int rs = j / conv_p_.IC;
+ int s = rs % conv_p_.KW;
+ int r = rs / conv_p_.KW;
+
+ int w_in = -conv_p_.pad_w + w * conv_p_.stride_w + s;
+ int h_in = -conv_p_.pad_h + h * conv_p_.stride_h + r;
+ // Please note that padding for convolution should be filled with zero_pt
+ if (h_in < 0 || h_in >= conv_p_.IH || w_in < 0 || w_in >= conv_p_.IW) {
+ out[(i - block.row_start) * BaseType::blockColSize() +
+ (j - block.col_start)] = BaseType::zeroPoint();
+ } else {
+ out[(i - block.row_start) * BaseType::blockColSize() +
+ (j - block.col_start)] = sdata_
+ [((n * conv_p_.IH + h_in) * conv_p_.IW + w_in) * conv_p_.IC + c];
+ }
+ }
+ // zero fill
+ // Please see the comment in PackAMatrix.cc for zero vs zero_pt fill.
+ for (int j = block.col_start + block.col_size;
+ j < block_p.col_start + block_p.col_size;
+ ++j) {
+ out[(i - block.row_start) * BaseType::blockColSize() +
+ (j - block.col_start)] = 0;
+ }
+ }
+}
+
+template <typename T, typename accT>
+void PackAWithIm2Col<T, accT>::printPackedMatrix(std::string name) {
+ std::cout << name << ":"
+ << "[" << BaseType::numPackedRows() << ", "
+ << BaseType::numPackedCols() << "]" << std::endl;
+
+ T* out = BaseType::getBuf();
+ for (auto r = 0; r < BaseType::numPackedRows(); ++r) {
+ for (auto c = 0; c < BaseType::numPackedCols(); ++c) {
+ T val = out[ r * BaseType::blockColSize() + c ];
+ if (std::is_integral<T>::value) {
+ // cast to int64 because cout doesn't print int8_t type directly
+ std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
+ } else {
+ std::cout << std::setw(5) << val << " ";
+ }
+ }
+ std::cout << std::endl;
+ }
+ std::cout << std::endl;
+}
+
+template <typename T, typename accT>
+int PackAWithIm2Col<T, accT>::rowOffsetBufferSize() {
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
+ } else if (cpuinfo_has_x86_avx2()) {
+ return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecture");
+ return -1;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+template class PackAWithIm2Col<uint8_t, int32_t>;
+template class PackAWithIm2Col<uint8_t, int16_t>;
+} // namespace fbgemm2