diff options
author | dskhudia <dskhudia@fb.com> | 2018-11-06 00:17:52 +0300 |
---|---|---|
committer | dskhudia <dskhudia@fb.com> | 2018-11-06 00:17:52 +0300 |
commit | b96bc0bf311f7abdc83ffd3af0a485b4aef53f7c (patch) | |
tree | 2a6c276d20753abe94c526aab7b109305e3d1d78 /src | |
parent | 14adee1ac506e067489406af689ae9b73fb581bd (diff) |
generalized conv_param_t and download third party libraries in build dir
Diffstat (limited to 'src')
-rw-r--r-- | src/ExecuteKernelU8S8.cc | 12 | ||||
-rw-r--r-- | src/Fbgemm.cc | 20 | ||||
-rw-r--r-- | src/PackAWithIm2Col.cc | 143 | ||||
-rw-r--r-- | src/PackMatrix.cc | 8 | ||||
-rw-r--r-- | src/RefImplementations.cc | 214 | ||||
-rw-r--r-- | src/RefImplementations.h | 22 |
6 files changed, 348 insertions, 71 deletions
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc index 5145869..e091a87 100644 --- a/src/ExecuteKernelU8S8.cc +++ b/src/ExecuteKernelU8S8.cc @@ -322,6 +322,12 @@ template class ExecuteKernel< memCopy<>>; template class ExecuteKernel< + PackAWithIm2Col<uint8_t, int16_t, 3>, + PackBMatrix<int8_t, int16_t>, + int32_t, + memCopy<>>; + +template class ExecuteKernel< PackAWithRowOffset<uint8_t, int32_t>, PackBMatrix<int8_t, int32_t>, int32_t, @@ -334,6 +340,12 @@ template class ExecuteKernel< memCopy<>>; template class ExecuteKernel< + PackAWithIm2Col<uint8_t, int32_t, 3>, + PackBMatrix<int8_t, int32_t>, + int32_t, + memCopy<>>; + +template class ExecuteKernel< PackAWithQuantRowOffset<uint8_t, int32_t>, PackBMatrix<int8_t, int32_t>, int32_t, diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc index f3bac97..9195a05 100644 --- a/src/Fbgemm.cc +++ b/src/Fbgemm.cc @@ -236,6 +236,16 @@ template void fbgemmPacked( int num_threads); template void fbgemmPacked( + PackMatrix<PackAWithIm2Col<uint8_t, int32_t, 3>, uint8_t, int32_t>& packA, + PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const memCopy<>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( PackMatrix<PackAWithQuantRowOffset<uint8_t, int32_t>, uint8_t, int32_t>& packA, PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>& packB, @@ -341,6 +351,16 @@ template void fbgemmPacked( int num_threads); template void fbgemmPacked( + PackMatrix<PackAWithIm2Col<uint8_t, int16_t, 3>, uint8_t, int16_t>& packA, + PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, + int32_t* C, + int32_t* C_buffer, + uint32_t ldc, + const memCopy<>& outProcess, + int thread_id, + int num_threads); + +template void fbgemmPacked( PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>& packA, PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, int32_t* C, diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc index e067a3e..8dde696 100644 --- a/src/PackAWithIm2Col.cc +++ b/src/PackAWithIm2Col.cc @@ -4,26 +4,37 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ -#include <cpuinfo.h> +#include <algorithm> #include <cassert> #include <iomanip> #include <iostream> -#include "fbgemm/Fbgemm.h" +#include <numeric> +#include <cpuinfo.h> -#include <algorithm> +#include "fbgemm/Fbgemm.h" namespace fbgemm2 { -template <typename T, typename accT> -PackAWithIm2Col<T, accT>::PackAWithIm2Col( - const conv_param_t& conv_p, +template <typename T, typename accT, int SPATIAL_DIM> +PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col( + const conv_param_t<SPATIAL_DIM>& conv_p, const T* sdata, inpType* pmat, int32_t zero_pt, int32_t* row_offset) - : PackMatrix<PackAWithIm2Col<T, accT>, T, accT>( - conv_p.MB * conv_p.OH * conv_p.OW, - conv_p.KH * conv_p.KW * conv_p.IC, + : PackMatrix<PackAWithIm2Col<T, accT, SPATIAL_DIM>, T, accT>( + conv_p.MB * + std::accumulate( + conv_p.OUT_DIM.begin(), + conv_p.OUT_DIM.end(), + 1, + std::multiplies<int>()), + std::accumulate( + conv_p.K.begin(), + conv_p.K.end(), + 1, + std::multiplies<int>()) * + conv_p.IC, pmat, zero_pt), conv_p_(conv_p), @@ -62,8 +73,8 @@ PackAWithIm2Col<T, accT>::PackAWithIm2Col( } } -template <typename T, typename accT> -void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) { +template <typename T, typename accT, int SPATIAL_DIM> +void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) { block_type_t block_p = {block.row_start, block.row_size, block.col_start, @@ -72,11 +83,87 @@ void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) { BaseType::packedBlock(block_p); T* out = BaseType::getBuf(); + if (SPATIAL_DIM == 3) { // static if + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + int n = + i / (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1] * conv_p_.OUT_DIM[2]); + int thw = + i % (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1] * conv_p_.OUT_DIM[2]); + int w = thw % conv_p_.OUT_DIM[2]; + int h = thw / conv_p_.OUT_DIM[2] % conv_p_.OUT_DIM[1]; + int t = thw / conv_p_.OUT_DIM[2] / conv_p_.OUT_DIM[1]; + for (int j = block.col_start; + j < block.col_start + block.col_size + conv_p_.IC - 1; + j += conv_p_.IC) { + int j_blk_id = j / conv_p_.IC; + // max( j_blk_id * IC, START) -> min( END, (j_blk_id + 1) * IC ) + int j_blk_start = std::max(j_blk_id * conv_p_.IC, block.col_start); + int j_blk_end = std::min( + (j_blk_id + 1) * conv_p_.IC, block.col_start + block.col_size); + if (j_blk_start >= j_blk_end) { + break; + } + + int qrs = j / conv_p_.IC; + int s = qrs % conv_p_.K[2]; + int r = qrs / conv_p_.K[2] % conv_p_.K[1]; + int q = qrs / conv_p_.K[2] / conv_p_.K[1]; + + int t_in = -conv_p_.pad[0] + t * conv_p_.stride[0] + q; + int h_in = -conv_p_.pad[1] + h * conv_p_.stride[1] + r; + int w_in = -conv_p_.pad[2] + w * conv_p_.stride[2] + s; + + if (t_in < 0 || t_in >= conv_p_.IN_DIM[0] || h_in < 0 || + h_in >= conv_p_.IN_DIM[1] || w_in < 0 || + w_in >= conv_p_.IN_DIM[2]) { + // Please note that padding for convolution should be filled with + // zero_pt + std::memset( + &out + [(i - block.row_start) * BaseType::blockColSize() + + (j_blk_start - block.col_start)], + BaseType::zeroPoint(), + sizeof(T) * (j_blk_end - j_blk_start)); + } else { + std::memcpy( + &out + [(i - block.row_start) * BaseType::blockColSize() + + j_blk_start - block.col_start], + &sdata_ + [(((n * conv_p_.IN_DIM[0] + t_in) * conv_p_.IN_DIM[1] + + h_in) * + conv_p_.IN_DIM[2] + + w_in) * + conv_p_.IC + + (j_blk_start % conv_p_.IC)], + sizeof(T) * (j_blk_end - j_blk_start)); + } + } + // zero fill + // Please see the comment in PackAMatrix.cc for zero vs zero_pt fill. + if ((block_p.col_start + block_p.col_size) - + (block.col_start + block.col_size) > + 0) { + std::memset( + &out + [(i - block.row_start) * BaseType::blockColSize() + + (block.col_size)], + 0, + sizeof(T) * + ((block_p.col_start + block_p.col_size) - + (block.col_start + block.col_size))); + } + } + return; + } + + assert(SPATIAL_DIM == 2 && "unsupported conv dimension"); + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { - int n = i / (conv_p_.OH * conv_p_.OW); - int hw = i % (conv_p_.OH * conv_p_.OW); - int w = hw % conv_p_.OW; - int h = hw / conv_p_.OW; + int n = i / (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1]); + int hw = i % (conv_p_.OUT_DIM[0] * conv_p_.OUT_DIM[1]); + int w = hw % conv_p_.OUT_DIM[1]; + int h = hw / conv_p_.OUT_DIM[1]; for (int j = block.col_start; j < block.col_start + block.col_size + conv_p_.IC - 1; j += conv_p_.IC) { @@ -90,13 +177,14 @@ void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) { } int rs = j / conv_p_.IC; - int s = rs % conv_p_.KW; - int r = rs / conv_p_.KW; + int s = rs % conv_p_.K[1]; + int r = rs / conv_p_.K[1]; - int w_in = -conv_p_.pad_w + w * conv_p_.stride_w + s; - int h_in = -conv_p_.pad_h + h * conv_p_.stride_h + r; + int h_in = -conv_p_.pad[0] + h * conv_p_.stride[0] + r; + int w_in = -conv_p_.pad[1] + w * conv_p_.stride[1] + s; - if (h_in < 0 || h_in >= conv_p_.IH || w_in < 0 || w_in >= conv_p_.IW) { + if (h_in < 0 || h_in >= conv_p_.IN_DIM[0] || w_in < 0 || + w_in >= conv_p_.IN_DIM[1]) { // Please note that padding for convolution should be filled with // zero_pt std::memset( @@ -111,7 +199,8 @@ void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) { [(i - block.row_start) * BaseType::blockColSize() + j_blk_start - block.col_start], &sdata_ - [((n * conv_p_.IH + h_in) * conv_p_.IW + w_in) * conv_p_.IC + + [((n * conv_p_.IN_DIM[0] + h_in) * conv_p_.IN_DIM[1] + w_in) * + conv_p_.IC + (j_blk_start % conv_p_.IC)], sizeof(T) * (j_blk_end - j_blk_start)); } @@ -133,8 +222,9 @@ void PackAWithIm2Col<T, accT>::pack(const block_type_t& block) { } } -template <typename T, typename accT> -void PackAWithIm2Col<T, accT>::printPackedMatrix(std::string name) { +template <typename T, typename accT, int SPATIAL_DIM> +void PackAWithIm2Col<T, accT, SPATIAL_DIM>::printPackedMatrix( + std::string name) { std::cout << name << ":" << "[" << BaseType::numPackedRows() << ", " << BaseType::numPackedCols() << "]" << std::endl; @@ -155,8 +245,8 @@ void PackAWithIm2Col<T, accT>::printPackedMatrix(std::string name) { std::cout << std::endl; } -template <typename T, typename accT> -int PackAWithIm2Col<T, accT>::rowOffsetBufferSize() { +template <typename T, typename accT, int SPATIAL_DIM> +int PackAWithIm2Col<T, accT, SPATIAL_DIM>::rowOffsetBufferSize() { if (cpuinfo_initialize()) { if (cpuinfo_has_x86_avx512f()) { return PackingTraits<T, accT, inst_set_t::avx512>::MCB; @@ -174,4 +264,7 @@ int PackAWithIm2Col<T, accT>::rowOffsetBufferSize() { template class PackAWithIm2Col<uint8_t, int32_t>; template class PackAWithIm2Col<uint8_t, int16_t>; +template class PackAWithIm2Col<uint8_t, int32_t, 3>; +template class PackAWithIm2Col<uint8_t, int16_t, 3>; + } // namespace fbgemm2 diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc index 85000ac..37b4e88 100644 --- a/src/PackMatrix.cc +++ b/src/PackMatrix.cc @@ -64,6 +64,10 @@ template class PackMatrix< int32_t>; template class PackMatrix<PackAWithIm2Col<uint8_t, int32_t>, uint8_t, int32_t>; +template class PackMatrix< + PackAWithIm2Col<uint8_t, int32_t, 3>, + uint8_t, + int32_t>; template class PackMatrix< PackAWithQuantRowOffset<uint8_t, int32_t>, @@ -74,6 +78,10 @@ template class PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>; // int16 accumulation template class PackMatrix<PackAWithIm2Col<uint8_t, int16_t>, uint8_t, int16_t>; +template class PackMatrix< + PackAWithIm2Col<uint8_t, int16_t, 3>, + uint8_t, + int16_t>; template class PackMatrix< PackAWithRowOffset<uint8_t, int16_t>, diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc index 6bf2d65..4b919c1 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -219,10 +219,10 @@ void spmdm_ref( } int32_t clip_16bit(int32_t x) { - if (x > std::numeric_limits<int16_t>::max()) { - return std::min<int>(std::numeric_limits<int16_t>::max(), x); - } else if (x < std::numeric_limits<int16_t>::min()) { - return std::max<int>(std::numeric_limits<int16_t>::min(), x); + if (x > numeric_limits<int16_t>::max()) { + return std::min<int>(numeric_limits<int16_t>::max(), x); + } else if (x < numeric_limits<int16_t>::min()) { + return std::max<int>(numeric_limits<int16_t>::min(), x); } else { return x; } @@ -235,36 +235,38 @@ int32_t clip_16bit(int32_t x) { * Ao: NHWC: NH_1W_1 x RSC_0 */ void im2col_ref( - const conv_param_t& conv_p, - const std::uint8_t* A, - std::int32_t A_zero_point, - std::uint8_t* Ao) { + const conv_param_t<>& conv_p, + const uint8_t* A, + int32_t A_zero_point, + uint8_t* Ao) { for (int n = 0; n < conv_p.MB; ++n) { - for (int h = 0; h < conv_p.OH; ++h) { - for (int w = 0; w < conv_p.OW; ++w) { - for (int r = 0; r < conv_p.KH; ++r) { - int h_in = -conv_p.pad_h + h * conv_p.stride_h + r; - for (int s = 0; s < conv_p.KW; ++s) { - int w_in = -conv_p.pad_w + w * conv_p.stride_w + s; - if (h_in < 0 || h_in >= conv_p.IH || w_in < 0 || - w_in >= conv_p.IW) { - std::memset( - &Ao[((((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.KH + r) * - conv_p.KW + + for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) { + for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) { + for (int r = 0; r < conv_p.K[0]; ++r) { + int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r; + for (int s = 0; s < conv_p.K[1]; ++s) { + int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s; + if (h_in < 0 || h_in >= conv_p.IN_DIM[0] || w_in < 0 || + w_in >= conv_p.IN_DIM[1]) { + memset( + &Ao[((((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) * + conv_p.K[0] + + r) * + conv_p.K[1] + s) * - conv_p.IC + - 0], + conv_p.IC], A_zero_point, sizeof(uint8_t) * conv_p.IC); } else { - std::memcpy( - &Ao[((((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.KH + r) * - conv_p.KW + + memcpy( + &Ao[((((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) * + conv_p.K[0] + + r) * + conv_p.K[1] + s) * - conv_p.IC + - 0], - &A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) * conv_p.IC + - 0], + conv_p.IC], + &A[((n * conv_p.IN_DIM[0] + h_in) * conv_p.IN_DIM[1] + w_in) * + conv_p.IC], sizeof(uint8_t) * conv_p.IC); } } // for each s @@ -274,44 +276,168 @@ void im2col_ref( } // for each n } +/* Imitate the Im2Col<float, CPUContext, StorageOrder::NHWC> function + * from caffe2/utils/math_cpu.cc + * NHWC StorageOrder/Layout + * A: NHWC: NT_0H_0W_0 x C_0 + * Ao: NHWC: NT_1H_1W_1 x QRSC_0 + */ +void im2col3d_ref( + const conv_param_t<3>& conv_p, + const uint8_t* A, + int32_t A_zero_point, + uint8_t* Ao) { + for (int n = 0; n < conv_p.MB; ++n) { + for (int t = 0; t < conv_p.OUT_DIM[0]; ++t) { + for (int h = 0; h < conv_p.OUT_DIM[1]; ++h) { + for (int w = 0; w < conv_p.OUT_DIM[2]; ++w) { + for (int q = 0; q < conv_p.K[0]; ++q) { + int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q; + for (int r = 0; r < conv_p.K[1]; ++r) { + int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r; + for (int s = 0; s < conv_p.K[2]; ++s) { + int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s; + if (t_in < 0 || t_in >= conv_p.IN_DIM[0] || h_in < 0 || + h_in >= conv_p.IN_DIM[1] || w_in < 0 || + w_in >= conv_p.IN_DIM[2]) { + memset( + &Ao[((((((n * conv_p.OUT_DIM[0] + t) * conv_p.OUT_DIM[1] + + h) * + conv_p.OUT_DIM[2] + + w) * + conv_p.K[0] + + q) * + conv_p.K[1] + + r) * + conv_p.K[2] + + s) * + conv_p.IC], + A_zero_point, + sizeof(uint8_t) * conv_p.IC); + } else { + memcpy( + &Ao[((((((n * conv_p.OUT_DIM[0] + t) * conv_p.OUT_DIM[1] + + h) * + conv_p.OUT_DIM[2] + + w) * + conv_p.K[0] + + q) * + conv_p.K[1] + + r) * + conv_p.K[2] + + s) * + conv_p.IC], + &A[(((n * conv_p.IN_DIM[0] + t_in) * conv_p.IN_DIM[1] + + h_in) * + conv_p.IN_DIM[2] + + w_in) * + conv_p.IC], + sizeof(uint8_t) * conv_p.IC); + } + } // for each s + } // for each r + } // for each q + } // for each w + } // for each h + } // for each t + } // for each n +} + void conv_ref( - const conv_param_t& conv_p, - const std::uint8_t* A, - std::int32_t A_zero_point, - const std::int8_t* B, - std::int32_t* C) { + const conv_param_t<>& conv_p, + const uint8_t* A, + int32_t A_zero_point, + const int8_t* B, + int32_t* C) { // filters are assumed to be in RSCK format assert(conv_p.G == 1 && "Groups != 1 not supported yet"); for (int n = 0; n < conv_p.MB; ++n) { - for (int h = 0; h < conv_p.OH; ++h) { - for (int w = 0; w < conv_p.OW; ++w) { + for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) { + for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) { for (int k = 0; k < conv_p.OC; ++k) { int sum = 0; - for (int r = 0; r < conv_p.KH; ++r) { - int h_in = -conv_p.pad_h + h * conv_p.stride_h + r; - for (int s = 0; s < conv_p.KW; ++s) { - int w_in = -conv_p.pad_w + w * conv_p.stride_w + s; + for (int r = 0; r < conv_p.K[0]; ++r) { + int h_in = -conv_p.pad[0] + h * conv_p.stride[0] + r; + for (int s = 0; s < conv_p.K[1]; ++s) { + int w_in = -conv_p.pad[1] + w * conv_p.stride[1] + s; for (int c = 0; c < conv_p.IC; ++c) { - int a = h_in < 0 || h_in >= conv_p.IH || w_in < 0 || - w_in >= conv_p.IW + int a = h_in < 0 || h_in >= conv_p.IN_DIM[0] || w_in < 0 || + w_in >= conv_p.IN_DIM[1] ? A_zero_point - : A[((n * conv_p.IH + h_in) * conv_p.IW + w_in) * + : A[((n * conv_p.IN_DIM[0] + h_in) * conv_p.IN_DIM[1] + + w_in) * conv_p.IC + c]; int b = - B[((r * conv_p.KW + s) * conv_p.IC + c) * conv_p.OC + k]; + B[((r * conv_p.K[1] + s) * conv_p.IC + c) * conv_p.OC + k]; sum += a * b; } // for each c } // for each s } // for each r - C[((n * conv_p.OH + h) * conv_p.OW + w) * conv_p.OC + k] = sum; + C[((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) * conv_p.OC + + k] = sum; } // for each k } // for each w } // for each h } // for each n } +void conv3d_ref( + const conv_param_t<3>& conv_p, + const uint8_t* A, + int32_t A_zero_point, + const int8_t* B, + int32_t* C) { + // filters are assumed to be in RSCK format + assert(conv_p.G == 1 && "Groups != 1 not supported yet"); + + for (int n = 0; n < conv_p.MB; ++n) { + for (int t = 0; t < conv_p.OUT_DIM[0]; ++t) { + for (int h = 0; h < conv_p.OUT_DIM[1]; ++h) { + for (int w = 0; w < conv_p.OUT_DIM[2]; ++w) { + for (int k = 0; k < conv_p.OC; ++k) { + int sum = 0; + for (int q = 0; q < conv_p.K[0]; ++q) { + int t_in = -conv_p.pad[0] + t * conv_p.stride[0] + q; + for (int r = 0; r < conv_p.K[1]; ++r) { + int h_in = -conv_p.pad[1] + h * conv_p.stride[1] + r; + for (int s = 0; s < conv_p.K[2]; ++s) { + int w_in = -conv_p.pad[2] + w * conv_p.stride[2] + s; + for (int c = 0; c < conv_p.IC; ++c) { + int a = t_in < 0 || t_in >= conv_p.IN_DIM[0] || h_in < 0 || + h_in >= conv_p.IN_DIM[1] || w_in < 0 || + w_in >= conv_p.IN_DIM[2] + ? A_zero_point + : A[(((n * conv_p.IN_DIM[0] + t_in) * conv_p.IN_DIM[1] + + h_in) * + conv_p.IN_DIM[2] + + w_in) * + conv_p.IC + + c]; + int b = + B[(((q * conv_p.K[1] + r) * conv_p.K[2] + s) * + conv_p.IC + + c) * + conv_p.OC + + k]; + sum += a * b; + } // for each c + } // for each s + } // for each r + } // for each q + C[(((n * conv_p.OUT_DIM[0] + t) * conv_p.OUT_DIM[1] + h) * + conv_p.OUT_DIM[2] + + w) * + conv_p.OC + + k] = sum; + } // for each k + } // for each w + } // for each h + } // for each t + } // for each n +} + void depthwise_3x3_pad_1_ref( int N, int H, diff --git a/src/RefImplementations.h b/src/RefImplementations.h index e9eaeed..69d060a 100644 --- a/src/RefImplementations.h +++ b/src/RefImplementations.h @@ -147,7 +147,14 @@ int32_t clip_16bit(int32_t x); * The output C is assumed to be in NHoWoC format. */ void conv_ref( - const conv_param_t& conv_p, + const conv_param_t<>& conv_p, + const std::uint8_t* A, + std::int32_t A_zero_point, + const std::int8_t* B, + std::int32_t* C); + +void conv3d_ref( + const conv_param_t<3>& conv_p, const std::uint8_t* A, std::int32_t A_zero_point, const std::int8_t* B, @@ -159,7 +166,18 @@ void conv_ref( * The output A is assumed to be in NHoWoRSC format. */ void im2col_ref( - const conv_param_t& conv_p, + const conv_param_t<>& conv_p, + const std::uint8_t* A, + std::int32_t A_zero_point, + std::uint8_t* Ao); + +/* + * @brief Reference implementation of im2col 3D operation. + * The input A is assumed to be in NTiHiWiC format. + * The output A is assumed to be in NToHoWoK0K1K2C format. + */ +void im2col3d_ref( + const conv_param_t<3>& conv_p, const std::uint8_t* A, std::int32_t A_zero_point, std::uint8_t* Ao); |