Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/RefImplementations.cc')
-rw-r--r--src/RefImplementations.cc28
1 files changed, 23 insertions, 5 deletions
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
index 72ef93f..b4b0c2b 100644
--- a/src/RefImplementations.cc
+++ b/src/RefImplementations.cc
@@ -285,8 +285,9 @@ int32_t clip_16bit(int32_t x) {
* A: NHWC: NH_0W_0 x C_0
* Ao: NHWC: NH_1W_1 x G RS C_0/G
*/
+template <>
void im2col_ref(
- const conv_param_t<>& conv_p,
+ const conv_param_t<2>& conv_p,
const uint8_t* A,
int32_t A_zero_point,
uint8_t* Ao) {
@@ -346,7 +347,8 @@ void im2col_ref(
* A: NHWC: NT_0H_0W_0 x C_0
* Ao: NHWC: NT_1H_1W_1 x G QRS C_0/G
*/
-void im2col3d_ref(
+template <>
+void im2col_ref(
const conv_param_t<3>& conv_p,
const uint8_t* A,
int32_t A_zero_point,
@@ -422,8 +424,10 @@ void im2col3d_ref(
} // for each n
}
+// 2D Conv
+template <>
void conv_ref(
- const conv_param_t<>& conv_p,
+ const conv_param_t<2>& conv_p,
const uint8_t* A,
int32_t A_zero_point,
const int8_t* B,
@@ -471,7 +475,9 @@ void conv_ref(
} // for each n
}
-void conv3d_ref(
+// 3D Conv
+template <>
+void conv_ref(
const conv_param_t<3>& conv_p,
const uint8_t* A,
int32_t A_zero_point,
@@ -531,10 +537,12 @@ void conv3d_ref(
} // for each n
}
+template <int SPATIAL_DIM>
void transposeConvWeights(
- const conv_param_t<>& conv_p,
+ const conv_param_t<SPATIAL_DIM>& conv_p,
const std::int8_t* src,
std::int8_t* dest) {
+ assert(SPATIAL_DIM == 2 && "Only 2D supported currently");
int R = conv_p.K[0];
int S = conv_p.K[1];
int G = conv_p.G;
@@ -956,4 +964,14 @@ void depthwise_3x3x3_per_channel_quantization_pad_1_ref(
}
};
+template void transposeConvWeights(
+ const conv_param_t<2>& conv_p,
+ const std::int8_t* src,
+ std::int8_t* dest);
+
+template void transposeConvWeights(
+ const conv_param_t<3>& conv_p,
+ const std::int8_t* src,
+ std::int8_t* dest);
+
} // namespace fbgemm