diff options
Diffstat (limited to 'src/RefImplementations.cc')
-rw-r--r-- | src/RefImplementations.cc | 28 |
1 files changed, 23 insertions, 5 deletions
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc index 72ef93f..b4b0c2b 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -285,8 +285,9 @@ int32_t clip_16bit(int32_t x) { * A: NHWC: NH_0W_0 x C_0 * Ao: NHWC: NH_1W_1 x G RS C_0/G */ +template <> void im2col_ref( - const conv_param_t<>& conv_p, + const conv_param_t<2>& conv_p, const uint8_t* A, int32_t A_zero_point, uint8_t* Ao) { @@ -346,7 +347,8 @@ void im2col_ref( * A: NHWC: NT_0H_0W_0 x C_0 * Ao: NHWC: NT_1H_1W_1 x G QRS C_0/G */ -void im2col3d_ref( +template <> +void im2col_ref( const conv_param_t<3>& conv_p, const uint8_t* A, int32_t A_zero_point, @@ -422,8 +424,10 @@ void im2col3d_ref( } // for each n } +// 2D Conv +template <> void conv_ref( - const conv_param_t<>& conv_p, + const conv_param_t<2>& conv_p, const uint8_t* A, int32_t A_zero_point, const int8_t* B, @@ -471,7 +475,9 @@ void conv_ref( } // for each n } -void conv3d_ref( +// 3D Conv +template <> +void conv_ref( const conv_param_t<3>& conv_p, const uint8_t* A, int32_t A_zero_point, @@ -531,10 +537,12 @@ void conv3d_ref( } // for each n } +template <int SPATIAL_DIM> void transposeConvWeights( - const conv_param_t<>& conv_p, + const conv_param_t<SPATIAL_DIM>& conv_p, const std::int8_t* src, std::int8_t* dest) { + assert(SPATIAL_DIM == 2 && "Only 2D supported currently"); int R = conv_p.K[0]; int S = conv_p.K[1]; int G = conv_p.G; @@ -956,4 +964,14 @@ void depthwise_3x3x3_per_channel_quantization_pad_1_ref( } }; +template void transposeConvWeights( + const conv_param_t<2>& conv_p, + const std::int8_t* src, + std::int8_t* dest); + +template void transposeConvWeights( + const conv_param_t<3>& conv_p, + const std::int8_t* src, + std::int8_t* dest); + } // namespace fbgemm |