Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJongsoo Park <jongsoo@fb.com>2018-11-29 06:41:59 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-11-29 06:44:08 +0300
commit027de07a11a0460fd1daffb026d50dba0e56eb79 (patch)
tree5e0f497059d6a22b18de8508c2031ffdbc9f52d3 /include/fbgemm/FbgemmI8Spmdm.h
parent90535d3da35f9d3da6a8dbd62da0c68d01696924 (diff)
sparse convolution output processing (#27)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/27 DoSpmdmOnInpBuffer can't be used together with PackAWithIm2Col because DoSpmdmOnInpBuffer expects im2col'ed A matrix. This diff implements DoSConvOnInpBuffer that does sparse convolution directly on A input without im2col. The performance is well optimized and need to see if this implementation is good enough to get good resnet50 performance. Reviewed By: dskhudia Differential Revision: D13192336 fbshipit-source-id: 2076555ba9749e111afbaec408a2bfa0f55bd5bc
Diffstat (limited to 'include/fbgemm/FbgemmI8Spmdm.h')
-rw-r--r--include/fbgemm/FbgemmI8Spmdm.h35
1 files changed, 33 insertions, 2 deletions
diff --git a/include/fbgemm/FbgemmI8Spmdm.h b/include/fbgemm/FbgemmI8Spmdm.h
index 3e040ad..ec72473 100644
--- a/include/fbgemm/FbgemmI8Spmdm.h
+++ b/include/fbgemm/FbgemmI8Spmdm.h
@@ -8,6 +8,7 @@
#include <cstdint>
#include <vector>
+#include "ConvUtils.h"
#include "Utils.h"
// #define FBGEMM_MEASURE_TIME_BREAKDOWN
@@ -21,6 +22,7 @@ extern double spmdm_transpose_32xN_time;
extern double spmdm_compute_time;
extern double spmdm_transpose_Nx32_time;
extern double spmdm_run_time;
+extern double sconv_run_time;
#endif
namespace fbgemm {
@@ -46,6 +48,19 @@ class CompressedSparseColumn {
std::vector<std::int8_t>& Values() {
return values_;
}
+ std::vector<std::int16_t>& KHs() {
+ return kh_;
+ }
+ std::vector<std::int16_t>& KWs() {
+ return kw_;
+ }
+ /**
+ * ICs include group: i.e. for ith input channels withint group g, ICs contain
+ * g*(groups_per_input_channels) + i
+ */
+ std::vector<std::int16_t>& ICs() {
+ return ic_;
+ }
std::size_t NumOfRows() const {
return num_rows_;
@@ -83,12 +98,28 @@ class CompressedSparseColumn {
std::int32_t* C,
int ldc) const;
+ void SparseConv(
+ const conv_param_t<>& conv_p,
+ const block_type_t& block,
+ const std::uint8_t* A,
+ std::int32_t A_zero_point,
+ bool accumulation,
+ std::int32_t* C,
+ int ldc) const;
+
private:
const std::size_t num_rows_;
- std::vector<std::int32_t> colptr_;
- std::vector<std::int16_t> rowidx_;
+ std::vector<std::int32_t> colptr_; // corresponds to out channels
std::vector<std::int8_t> values_;
+ // For SpMDM
+ std::vector<std::int16_t> rowidx_; // kh kw ic are flattened with im2col
+
+ // For direct sparse convolution
+ std::vector<std::int16_t> kh_;
+ std::vector<std::int16_t> kw_;
+ std::vector<std::int16_t> ic_; // in channels
+
// Cache IsHyperSparse to minimize its overhead.
mutable bool hyper_sparse_;