diff options
author | Daya Khudia <dskhudia@fb.com> | 2019-07-16 03:34:48 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-07-16 03:47:38 +0300 |
commit | feca34d3d0b52527a1caa53186bc2ebf72d227c2 (patch) | |
tree | 09c856f5f8bbe03ff8d594c8c3083f6cb244ae49 | |
parent | e69972dad13a116049a8f9e6657c74ac5d04207a (diff) |
Add functions needed for unpacking in PackWeightsForConv (#106)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/106
The values returned by these functions is needed while unpacking weights.
Reviewed By: jianyuh
Differential Revision: D16193425
fbshipit-source-id: 8ee3a0dc46768d7cb572bf383be1ce2b450c44c9
-rw-r--r-- | include/fbgemm/Fbgemm.h | 17 | ||||
-rw-r--r-- | src/PackWeightsForConv.cc | 3 |
2 files changed, 19 insertions, 1 deletions
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index 302af51..9cee28f 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -597,6 +597,22 @@ class FBGEMM_API PackWeightsForConv { return W_gconv_packed_; } + int inputChannels() { + return conv_param_.IC; + } + + int outputChannels() { + return conv_param_.OC; + } + + std::array<int, SPATIAL_DIM> kernelDims() { + return conv_param_.K; + } + + int groups() { + return conv_param_.G; + } + /** * @brief Unpack packed matric into origin_buf (Used for the serialization to * recover weight matrix). @@ -604,6 +620,7 @@ class FBGEMM_API PackWeightsForConv { void unpack(T* origin_buf); private: + const conv_param_t<SPATIAL_DIM> conv_param_; // Packed weights if we use im2col based convolution implementation std::shared_ptr<PackBMatrix<T, accT>> W_im2col_packed_; // Packed weights if we use 2D depthwise convolution implementation diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc index e16843c..085adc0 100644 --- a/src/PackWeightsForConv.cc +++ b/src/PackWeightsForConv.cc @@ -13,7 +13,8 @@ template <int SPATIAL_DIM, typename T, typename accT> PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv( const conv_param_t<SPATIAL_DIM>& conv_p, const T* sdata, - const BlockingFactors* blocking_params) { + const BlockingFactors* blocking_params) + : conv_param_(conv_p) { static_assert( SPATIAL_DIM == 2 || SPATIAL_DIM == 3, "Only 2D and 3D convolutions are supported"); |