Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
author dskhudia <dskhudia@fb.com> 2018-11-06 00:17:52 +0300
committer dskhudia <dskhudia@fb.com> 2018-11-06 00:17:52 +0300
commit b96bc0bf311f7abdc83ffd3af0a485b4aef53f7c (patch)
tree 2a6c276d20753abe94c526aab7b109305e3d1d78 /src
parent 14adee1ac506e067489406af689ae9b73fb581bd (diff)
generalized conv_param_t and download third party libraries in build dir
Diffstat (limited to 'src')
-rw-r--r-- src/ExecuteKernelU8S8.cc 12
-rw-r--r-- src/Fbgemm.cc 20
-rw-r--r-- src/PackAWithIm2Col.cc 143
-rw-r--r-- src/PackMatrix.cc 8
-rw-r--r-- src/RefImplementations.cc 214
-rw-r--r-- src/RefImplementations.h 22
6 files changed, 348 insertions, 71 deletions
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc
index 5145869..e091a87 100644
--- a/src/ExecuteKernelU8S8.cc
+++ b/src/ExecuteKernelU8S8.cc
@@ -322,6 +322,12 @@ template class ExecuteKernel<
memCopy<>>;
template class ExecuteKernel<
+ PackAWithIm2Col<uint8_t, int16_t, 3>,
+ PackBMatrix<int8_t, int16_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
PackAWithRowOffset<uint8_t, int32_t>,
PackBMatrix<int8_t, int32_t>,
int32_t,
@@ -334,6 +340,12 @@ template class ExecuteKernel<
memCopy<>>;
template class ExecuteKernel<
+ PackAWithIm2Col<uint8_t, int32_t, 3>,
+ PackBMatrix<int8_t, int32_t>,
+ int32_t,
+ memCopy<>>;
+
+template class ExecuteKernel<
PackAWithQuantRowOffset<uint8_t, int32_t>,
PackBMatrix<int8_t, int32_t>,
int32_t,
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index f3bac97..9195a05 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -236,6 +236,16 @@ template void fbgemmPacked(
int num_threads);
template void fbgemmPacked(
+ PackMatrix<PackAWithIm2Col<uint8_t, int32_t, 3>, uint8_t, int32_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
PackMatrix<PackAWithQuantRowOffset<uint8_t, int32_t>, uint8_t, int32_t>&
packA,
PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB,
@@ -341,6 +351,16 @@ template void fbgemmPacked(
int num_threads);
template void fbgemmPacked(
+ PackMatrix<PackAWithIm2Col<uint8_t, int16_t, 3>, uint8_t, int16_t>& packA,
+ PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
+ int32_t* C,
+ int32_t* C_buffer,
+ uint32_t ldc,
+ const memCopy<>& outProcess,
+ int thread_id,
+ int num_threads);
+
+template void fbgemmPacked(
PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>& packA,
PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
int32_t* C,
diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc
index e067a3e..8dde696 100644
--- a/src/PackAWithIm2Col.cc
+++ b/src/PackAWithIm2Col.cc
@@ -4,26 +4,37 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
-#include <cpuinfo.h>
+#include <algorithm>
#include <cassert>
#include <iomanip>
#include <iostream>
-#include "fbgemm/Fbgemm.h"
+#include <numeric>
+#include <cpuinfo.h>
-#include <algorithm>
+#include "fbgemm/Fbgemm.h"
namespace fbgemm2 {
-template <typename T, typename accT>
-PackAWithIm2Col<T, accT>::PackAWithIm2Col(
- const conv_param_t& conv_p,
+template <typename T, typename accT, int SPATIAL_DIM>
+PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col(
+ const conv_param_t<SPATIAL_DIM>& conv_p,
const T* sdata,
inpType* pmat,
int32_t zero_pt,
int32_t* row_offset)
- : PackMatrix<PackAWithIm2Col<T, accT>, T, accT>(
- conv_p.MB * conv_p.OH * conv_p.OW,
- conv_p.KH * conv_p.KW * conv_p.IC,
+ : PackMatrix<PackAWithIm2Col<T, accT, SPATIAL_DIM>, T, accT>(
+ conv_p.MB *
+ std::accumulate(
+ conv_p.OUT_DIM.begin(),
+ conv_p.OUT_DIM.end(),
+ 1,
+ std::multiplies<int>()),
+ std::accumulate(
+ conv_p.K.begin(),
+ conv_p.K.end(),
+ 1,
+ std::multiplies<int>()) *
+ conv_p.IC,
pmat,
zero_pt),
conv_p_(conv_p),
@@ -62,8 +73,8 @@ PackAWithIm2Col<T, accT>::PackAWithIm2Col(
}
}
-template <typename T, typename accT>
-void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) {
+template <typename T, typename accT, int SPATIAL_DIM>
+void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) {
block_type_t block_p = {block.row_start,
block.row_size,
block.col_start,
@@ -72,11 +83,87 @@ void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) {
BaseType::packedBlock(block_p);
T* out = BaseType::getBuf();
+ if (SPATIAL_DIM == 3) { // static if
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ int n =
+ i / (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1] * conv_p_.OUT_DIM[2]);
+ int thw =
+ i % (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1] * conv_p_.OUT_DIM[2]);
+ int w = thw % conv_p_.OUT_DIM[2];
+ int h = thw / conv_p_.OUT_DIM[2] % conv_p_.OUT_DIM[1];
+ int t = thw / conv_p_.OUT_DIM[2] / conv_p_.OUT_DIM[1];
+ for (int j = block.col_start;
+ j < block.col_start + block.col_size + conv_p_.IC - 1;
+ j += conv_p_.IC) {
+ int j_blk_id = j / conv_p_.IC;
+ // max( j_blk_id * IC, START) -> min( END, (j_blk_id + 1) * IC )
+ int j_blk_start = std::max(j_blk_id * conv_p_.IC, block.col_start);
+ int j_blk_end = std::min(
+ (j_blk_id + 1) * conv_p_.IC, block.col_start + block.col_size);
+ if (j_blk_start >= j_blk_end) {
+ break;
+ }
+
+ int qrs = j / conv_p_.IC;
+ int s = qrs % conv_p_.K[2];
+ int r = qrs / conv_p_.K[2] % conv_p_.K[1];
+ int q = qrs / conv_p_.K[2] / conv_p_.K[1];
+
+ int t_in = -conv_p_.pad[0] + t * conv_p_.stride[0] + q;
+ int h_in = -conv_p_.pad[1] + h * conv_p_.stride[1] + r;
+ int w_in = -conv_p_.pad[2] + w * conv_p_.stride[2] + s;
+
+ if (t_in < 0 || t_in >= conv_p_.IN_DIM[0] || h_in < 0 ||
+ h_in >= conv_p_.IN_DIM[1] || w_in < 0 ||
+ w_in >= conv_p_.IN_DIM[2]) {
+ // Please note that padding for convolution should be filled with
+ // zero_pt
+ std::memset(
+ &out
+ [(i - block.row_start) * BaseType::blockColSize() +
+ (j_blk_start - block.col_start)],
+ BaseType::zeroPoint(),
+ sizeof(T) * (j_blk_end - j_blk_start));
+ } else {
+ std::memcpy(
+ &out
+ [(i - block.row_start) * BaseType::blockColSize() +
+ j_blk_start - block.col_start],
+ &sdata_
+ [(((n * conv_p_.IN_DIM[0] + t_in) * conv_p_.IN_DIM[1] +
+ h_in) *
+ conv_p_.IN_DIM[2] +
+ w_in) *
+ conv_p_.IC +
+ (j_blk_start % conv_p_.IC)],
+ sizeof(T) * (j_blk_end - j_blk_start));
+ }
+ }
+ // zero fill
+ // Please see the comment in PackAMatrix.cc for zero vs zero_pt fill.
+ if ((block_p.col_start + block_p.col_size) -
+ (block.col_start + block.col_size) >
+ 0) {
+ std::memset(
+ &out
+ [(i - block.row_start) * BaseType::blockColSize() +
+ (block.col_size)],
+ 0,
+ sizeof(T) *
+ ((block_p.col_start + block_p.col_size) -
+ (block.col_start + block.col_size)));
+ }
+ }
+ return;
+ }
+
+ assert(SPATIAL_DIM == 2 && "unsupported conv dimension");
+
for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
- int n = i / (conv_p_.OH * conv_p_.OW);
- int hw = i % (conv_p_.OH * conv_p_.OW);
- int w = hw % conv_p_.OW;
- int h = hw / conv_p_.OW;
+ int n = i / (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1]);
+ int hw = i % (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1]);
+ int w = hw % conv_p_.OUT_DIM[1];
+ int h = hw / conv_p_.OUT_DIM[1];
for (int j = block.col_start;
j < block.col_start + block.col_size + conv_p_.IC - 1;
j += conv_p_.IC) {
@@ -90,13 +177,14 @@ void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) {
}
int rs = j / conv_p_.IC;
- int s = rs % conv_p_.KW;
- int r = rs / conv_p_.KW;
+ int s = rs % conv_p_.K[1];
+ int r = rs / conv_p_.K[1];
- int w_in = -conv_p_.pad_w + w * conv_p_.stride_w + s;
- int h_in = -conv_p_.pad_h + h * conv_p_.stride_h + r;
+ int h_in = -conv_p_.pad[0] + h * conv_p_.stride[0] + r;
+ int w_in = -conv_p_.pad[1] + w * conv_p_.stride[1] + s;
- if (h_in < 0 || h_in >= conv_p_.IH || w_in < 0 || w_in >= conv_p_.IW) {
+ if (h_in < 0 || h_in >= conv_p_.IN_DIM[0] || w_in < 0 ||
+ w_in >= conv_p_.IN_DIM[1]) {
// Please note that padding for convolution should be filled with
// zero_pt
std::memset(
@@ -111,7 +199,8 @@ void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) {
[(i - block.row_start) * BaseType::blockColSize() +
j_blk_start - block.col_start],
&sdata_
- [((n * conv_p_.IH + h_in) * conv_p_.IW + w_in) * conv_p_.IC +
+ [((n * conv_p_.IN_DIM[0] + h_in) * conv_p_.IN_DIM[1] + w_in) *
+ conv_p_.IC +
(j_blk_start % conv_p_.IC)],
sizeof(T) * (j_blk_end - j_blk_start));
}
@@ -133,8 +222,9 @@ void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) {
}
}
-template <typename T, typename accT>
-void PackAWithIm2Col<T, accT>::printPackedMatrix(std::string name) {
+template <typename T, typename accT, int SPATIAL_DIM>
+void PackAWithIm2Col<T, accT, SPATIAL_DIM>::printPackedMatrix(
+ std::string name) {
std::cout << name << ":"
<< "[" << BaseType::numPackedRows() << ", "
<< BaseType::numPackedCols() << "]" << std::endl;
@@ -155,8 +245,8 @@ void PackAWithIm2Col<T, accT>::printPackedMatrix(std::string name) {
std::cout << std::endl;
}
-template <typename T, typename accT>
-int PackAWithIm2Col<T, accT>::rowOffsetBufferSize() {
+template <typename T, typename accT, int SPATIAL_DIM>
+int PackAWithIm2Col<T, accT, SPATIAL_DIM>::rowOffsetBufferSize() {
if (cpuinfo_initialize()) {
if (cpuinfo_has_x86_avx512f()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
@@ -174,4 +264,7 @@ int PackAWithIm2Col<T, accT>::rowOffsetBufferSize() {
template class PackAWithIm2Col<uint8_t, int32_t>;
template class PackAWithIm2Col<uint8_t, int16_t>;
+template class PackAWithIm2Col<uint8_t, int32_t, 3>;
+template class PackAWithIm2Col<uint8_t, int16_t, 3>;
+
} // namespace fbgemm2
diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc
index 85000ac..37b4e88 100644
--- a/src/PackMatrix.cc
+++ b/src/PackMatrix.cc
@@ -64,6 +64,10 @@ template class PackMatrix<
int32_t>;
template class PackMatrix<PackAWithIm2Col<uint8_t, int32_t>, uint8_t, int32_t>;
+template class PackMatrix<
+ PackAWithIm2Col<uint8_t, int32_t, 3>,
+ uint8_t,
+ int32_t>;
template class PackMatrix<
PackAWithQuantRowOffset<uint8_t, int32_t>,
@@ -74,6 +78,10 @@ template class PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>;
// int16 accumulation
template class PackMatrix<PackAWithIm2Col<uint8_t, int16_t>, uint8_t, int16_t>;
+template class PackMatrix<
+ PackAWithIm2Col<uint8_t, int16_t, 3>,
+ uint8_t,
+ int16_t>;
template class PackMatrix<
PackAWithRowOffset<uint8_t, int16_t>,
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
index 6bf2d65..4b919c1 100644
--- a/src/RefImplementations.cc
+++ b/src/RefImplementations.cc
@@ -219,10 +219,10 @@ void spmdm_ref(
}
int32_t clip_16bit(int32_t x) {
- if (x > std::numeric_limits<int16_t>::max()) {
- return std::min<int>(std::numeric_limits<int16_t>::max(), x);
- } else if (x < std::numeric_limits<int16_t>::min()) {
- return std::max<int>(std::numeric_limits<int16_t>::min(), x);
+ if (x > numeric_limits<int16_t>::max()) {
+ return std::min<int>(numeric_limits<int16_t>::max(), x);
+ } else if (x < numeric_limits<int16_t>::min()) {
+ return std::max<int>(numeric_limits<int16_t>::min(), x);
} else {
return x;
}
@@ -235,36 +235,38 @@ int32_t clip_16bit(int32_t x) {
* Ao: NHWC: NH_1W_1 x RSC_0
*/
void im2col_ref(
- const conv_param_t& conv_p,
- const std::uint8_t* A,
- std::int32_t A_zero_point,
- std::uint8_t* Ao) {
+ const conv_param_t<>& conv_p,
+ const uint8_t* A,
+ int32_t A_zero_point,
+ uint8_t* Ao) {
for (int n = 0; n < conv_p.MB; ++n) {
- for (int h = 0; h < conv_p.OH; ++h) {
- for (int w = 0; w < conv_p.OW; ++w) {
- for (int r = 0; r < conv_p.KH; ++r) {
- int h_in = -conv_p.pad_h + h * conv_p.stride_h + r;
- for (int s = 0; s < conv_p.KW; ++s) {
- int w_in = -conv_p.pad_w + w * conv_p.stride_w + s;
- if (h_in < 0 || h_in >= conv_p.IH || w_in < 0 ||
- w_in >= conv_p.IW) {
- std::memset(
- &Ao[((((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.KH + r) *
- conv_p.KW +
+ for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) {
+ for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) {
+ for (int r = 0; r < conv_p.K[0]; ++r) {
+ int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r;
+ for (int s = 0; s < conv_p.K[1]; ++s) {
+ int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s;
+ if (h_in < 0 || h_in >= conv_p.IN_DIM[0] || w_in < 0 ||
+ w_in >= conv_p.IN_DIM[1]) {
+ memset(
+ &Ao[((((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) *
+ conv_p.K[0] +
+ r) *
+ conv_p.K[1] +
s) *
- conv_p.IC +
- 0],
+ conv_p.IC],
A_zero_point,
sizeof(uint8_t) * conv_p.IC);
} else {
- std::memcpy(
- &Ao[((((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.KH + r) *
- conv_p.KW +
+ memcpy(
+ &Ao[((((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) *
+ conv_p.K[0] +
+ r) *
+ conv_p.K[1] +
s) *
- conv_p.IC +
- 0],
- &A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) * conv_p.IC +
- 0],
+ conv_p.IC],
+ &A[((n * conv_p.IN_DIM[0] + h_in) * conv_p.IN_DIM[1] + w_in) *
+ conv_p.IC],
sizeof(uint8_t) * conv_p.IC);
}
} // for each s
@@ -274,44 +276,168 @@ void im2col_ref(
} // for each n
}
+/* Imitate the Im2Col<float, CPUContext, StorageOrder::NHWC> function
+ * from caffe2/utils/math_cpu.cc
+ * NHWC StorageOrder/Layout
+ * A: NHWC: NT_0H_0W_0 x C_0
+ * Ao: NHWC: NT_1H_1W_1 x QRSC_0
+ */
+void im2col3d_ref(
+ const conv_param_t<3>& conv_p,
+ const uint8_t* A,
+ int32_t A_zero_point,
+ uint8_t* Ao) {
+ for (int n = 0; n < conv_p.MB; ++n) {
+ for (int t = 0; t < conv_p.OUT_DIM[0]; ++t) {
+ for (int h = 0; h < conv_p.OUT_DIM[1]; ++h) {
+ for (int w = 0; w < conv_p.OUT_DIM[2]; ++w) {
+ for (int q = 0; q < conv_p.K[0]; ++q) {
+ int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q;
+ for (int r = 0; r < conv_p.K[1]; ++r) {
+ int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r;
+ for (int s = 0; s < conv_p.K[2]; ++s) {
+ int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s;
+ if (t_in < 0 || t_in >= conv_p.IN_DIM[0] || h_in < 0 ||
+ h_in >= conv_p.IN_DIM[1] || w_in < 0 ||
+ w_in >= conv_p.IN_DIM[2]) {
+ memset(
+ &Ao[((((((n * conv_p.OUT_DIM[0] + t) * conv_p.OUT_DIM[1] +
+ h) *
+ conv_p.OUT_DIM[2] +
+ w) *
+ conv_p.K[0] +
+ q) *
+ conv_p.K[1] +
+ r) *
+ conv_p.K[2] +
+ s) *
+ conv_p.IC],
+ A_zero_point,
+ sizeof(uint8_t) * conv_p.IC);
+ } else {
+ memcpy(
+ &Ao[((((((n * conv_p.OUT_DIM[0] + t) * conv_p.OUT_DIM[1] +
+ h) *
+ conv_p.OUT_DIM[2] +
+ w) *
+ conv_p.K[0] +
+ q) *
+ conv_p.K[1] +
+ r) *
+ conv_p.K[2] +
+ s) *
+ conv_p.IC],
+ &A[(((n * conv_p.IN_DIM[0] + t_in) * conv_p.IN_DIM[1] +
+ h_in) *
+ conv_p.IN_DIM[2] +
+ w_in) *
+ conv_p.IC],
+ sizeof(uint8_t) * conv_p.IC);
+ }
+ } // for each s
+ } // for each r
+ } // for each q
+ } // for each w
+ } // for each h
+ } // for each t
+ } // for each n
+}
+
void conv_ref(
- const conv_param_t& conv_p,
- const std::uint8_t* A,
- std::int32_t A_zero_point,
- const std::int8_t* B,
- std::int32_t* C) {
+ const conv_param_t<>& conv_p,
+ const uint8_t* A,
+ int32_t A_zero_point,
+ const int8_t* B,
+ int32_t* C) {
// filters are assumed to be in RSCK format
assert(conv_p.G == 1 && "Groups != 1 not supported yet");
for (int n = 0; n < conv_p.MB; ++n) {
- for (int h = 0; h < conv_p.OH; ++h) {
- for (int w = 0; w < conv_p.OW; ++w) {
+ for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) {
+ for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) {
for (int k = 0; k < conv_p.OC; ++k) {
int sum = 0;
- for (int r = 0; r < conv_p.KH; ++r) {
- int h_in = -conv_p.pad_h + h * conv_p.stride_h + r;
- for (int s = 0; s < conv_p.KW; ++s) {
- int w_in = -conv_p.pad_w + w * conv_p.stride_w + s;
+ for (int r = 0; r < conv_p.K[0]; ++r) {
+ int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r;
+ for (int s = 0; s < conv_p.K[1]; ++s) {
+ int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s;
for (int c = 0; c < conv_p.IC; ++c) {
- int a = h_in < 0 || h_in >= conv_p.IH || w_in < 0 ||
- w_in >= conv_p.IW
+ int a = h_in < 0 || h_in >= conv_p.IN_DIM[0] || w_in < 0 ||
+ w_in >= conv_p.IN_DIM[1]
? A_zero_point
- : A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) *
+ : A[((n * conv_p.IN_DIM[0] + h_in) * conv_p.IN_DIM[1] +
+ w_in) *
conv_p.IC +
c];
int b =
- B[((r * conv_p.KW + s) * conv_p.IC + c) * conv_p.OC + k];
+ B[((r * conv_p.K[1] + s) * conv_p.IC + c) * conv_p.OC + k];
sum += a * b;
} // for each c
} // for each s
} // for each r
- C[((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k] = sum;
+ C[((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) * conv_p.OC +
+ k] = sum;
} // for each k
} // for each w
} // for each h
} // for each n
}
+void conv3d_ref(
+ const conv_param_t<3>& conv_p,
+ const uint8_t* A,
+ int32_t A_zero_point,
+ const int8_t* B,
+ int32_t* C) {
+ // filters are assumed to be in RSCK format
+ assert(conv_p.G == 1 && "Groups != 1 not supported yet");
+
+ for (int n = 0; n < conv_p.MB; ++n) {
+ for (int t = 0; t < conv_p.OUT_DIM[0]; ++t) {
+ for (int h = 0; h < conv_p.OUT_DIM[1]; ++h) {
+ for (int w = 0; w < conv_p.OUT_DIM[2]; ++w) {
+ for (int k = 0; k < conv_p.OC; ++k) {
+ int sum = 0;
+ for (int q = 0; q < conv_p.K[0]; ++q) {
+ int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q;
+ for (int r = 0; r < conv_p.K[1]; ++r) {
+ int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r;
+ for (int s = 0; s < conv_p.K[2]; ++s) {
+ int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s;
+ for (int c = 0; c < conv_p.IC; ++c) {
+ int a = t_in < 0 || t_in >= conv_p.IN_DIM[0] || h_in < 0 ||
+ h_in >= conv_p.IN_DIM[1] || w_in < 0 ||
+ w_in >= conv_p.IN_DIM[2]
+ ? A_zero_point
+ : A[(((n * conv_p.IN_DIM[0] + t_in) * conv_p.IN_DIM[1] +
+ h_in) *
+ conv_p.IN_DIM[2] +
+ w_in) *
+ conv_p.IC +
+ c];
+ int b =
+ B[(((q * conv_p.K[1] + r) * conv_p.K[2] + s) *
+ conv_p.IC +
+ c) *
+ conv_p.OC +
+ k];
+ sum += a * b;
+ } // for each c
+ } // for each s
+ } // for each r
+ } // for each q
+ C[(((n * conv_p.OUT_DIM[0] + t) * conv_p.OUT_DIM[1] + h) *
+ conv_p.OUT_DIM[2] +
+ w) *
+ conv_p.OC +
+ k] = sum;
+ } // for each k
+ } // for each w
+ } // for each h
+ } // for each t
+ } // for each n
+}
+
void depthwise_3x3_pad_1_ref(
int N,
int H,
diff --git a/src/RefImplementations.h b/src/RefImplementations.h
index e9eaeed..69d060a 100644
--- a/src/RefImplementations.h
+++ b/src/RefImplementations.h
@@ -147,7 +147,14 @@ int32_t clip_16bit(int32_t x);
* The output C is assumed to be in NHoWoC format.
*/
void conv_ref(
- const conv_param_t& conv_p,
+ const conv_param_t<>& conv_p,
+ const std::uint8_t* A,
+ std::int32_t A_zero_point,
+ const std::int8_t* B,
+ std::int32_t* C);
+
+void conv3d_ref(
+ const conv_param_t<3>& conv_p,
const std::uint8_t* A,
std::int32_t A_zero_point,
const std::int8_t* B,
@@ -159,7 +166,18 @@ void conv_ref(
* The output A is assumed to be in NHoWoRSC format.
*/
void im2col_ref(
- const conv_param_t& conv_p,
+ const conv_param_t<>& conv_p,
+ const std::uint8_t* A,
+ std::int32_t A_zero_point,
+ std::uint8_t* Ao);
+
+/*
+ * @brief Reference implementation of im2col 3D operation.
+ * The input A is assumed to be in NTiHiWiC format.
+ * The output A is assumed to be in NToHoWoK0K1K2C format.
+ */
+void im2col3d_ref(
+ const conv_param_t<3>& conv_p,
const std::uint8_t* A,
std::int32_t A_zero_point,
std::uint8_t* Ao);