Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/PackWeightsForConv.cc')
-rw-r--r--src/PackWeightsForConv.cc24
1 files changed, 24 insertions, 0 deletions
diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc
index 085adc0..ccb97fd 100644
--- a/src/PackWeightsForConv.cc
+++ b/src/PackWeightsForConv.cc
@@ -4,6 +4,7 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
+#include <algorithm>
#include <memory>
#include "fbgemm/Fbgemm.h"
@@ -81,6 +82,29 @@ void PackWeightsForConv<SPATIAL_DIM, T, accT>::unpack(T* origin_buf) {
}
}
+template <int SPATIAL_DIM, typename T, typename accT>
+bool PackWeightsForConv<SPATIAL_DIM, T, accT>::isPackingCompliant(
+ const conv_param_t<SPATIAL_DIM>& test_conv_p) {
+ return conv_param_.IC == test_conv_p.IC && conv_param_.OC == test_conv_p.OC &&
+ conv_param_.G == test_conv_p.G &&
+ std::equal(
+ conv_param_.K.begin(),
+ conv_param_.K.end(),
+ test_conv_p.K.begin()) &&
+ std::equal(
+ conv_param_.stride.begin(),
+ conv_param_.stride.end(),
+ test_conv_p.stride.begin()) &&
+ std::equal(
+ conv_param_.pad.begin(),
+ conv_param_.pad.end(),
+ test_conv_p.pad.begin()) &&
+ std::equal(
+ conv_param_.dilation.begin(),
+ conv_param_.dilation.end(),
+ test_conv_p.dilation.begin());
+}
+
template class PackWeightsForConv<2, int8_t, int32_t>;
template class PackWeightsForConv<3, int8_t, int32_t>;