diff options
author | Daya Khudia <dskhudia@fb.com> | 2019-06-12 21:26:39 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-06-12 21:32:28 +0300 |
commit | 5e71d2c304663f3b4e50cee723b8e98a867d11ca (patch) | |
tree | a3932b85eef3c74abecaf96e2a2d5b32261ef98e /src | |
parent | bf2f45f35cc0d7b6f420894652824a377f764714 (diff) |
Print packed matrix for each group as well
Summary: same as title. We were only printing packed matrix for group 0
Reviewed By: jianyuh
Differential Revision: D15775235
fbshipit-source-id: 747550c9ae229a2eeb912409897c1331ada81e2b
Diffstat (limited to 'src')
-rw-r--r-- | src/PackBMatrix.cc | 52 |
1 files changed, 29 insertions, 23 deletions
diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc index b6d06ca..76198ca 100644 --- a/src/PackBMatrix.cc +++ b/src/PackBMatrix.cc @@ -382,33 +382,39 @@ void PackBMatrix<T, accT>::printPackedMatrix(std::string name) { << "[" << BaseType::blockRowSize() << ", " << BaseType::blockColSize() << "]" << std::endl; - T* out = BaseType::getBuf(); - - for (auto nr = 0; nr < BaseType::blockRows(); ++nr) { - auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow() - : BaseType::blockRowSize(); - for (auto nc = 0; nc < BaseType::blockCols(); ++nc) { - std::cout << "block:" << nr << ", " << nc << std::endl; - auto cols = (nc == BaseType::blockCols() - 1) ? BaseType::lastBcol() - : BaseType::blockColSize(); - for (auto r = 0; r < (rows + row_interleave_ - 1) / row_interleave_; - ++r) { - for (auto c = 0; c < cols * row_interleave_; ++c) { - T val = - out[nr * BaseType::blockCols() * BaseType::blockRowSize() * - BaseType::blockColSize() + - nc * BaseType::blockRowSize() * BaseType::blockColSize() + - r * BaseType::blockColSize() * row_interleave_ + c]; - if (std::is_integral<T>::value) { - // cast to int64 because cout doesn't print int8_t type directly - std::cout << std::setw(5) << static_cast<int64_t>(val) << " "; - } else { - std::cout << std::setw(5) << val << " "; + for (int g = 0; g < BaseType::numGroups(); ++g) { + T* out = BaseType::getBuf() + + g * + BaseType::packedBufferSize( + BaseType::numPackedRows(), BaseType::numPackedCols()); + std::cout << "group: " << g << std::endl; + for (auto nr = 0; nr < BaseType::blockRows(); ++nr) { + auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow() + : BaseType::blockRowSize(); + for (auto nc = 0; nc < BaseType::blockCols(); ++nc) { + std::cout << "block:" << nr << ", " << nc << std::endl; + auto cols = (nc == BaseType::blockCols() - 1) + ? BaseType::lastBcol() + : BaseType::blockColSize(); + for (auto r = 0; r < (rows + row_interleave_ - 1) / row_interleave_; + ++r) { + for (auto c = 0; c < cols * row_interleave_; ++c) { + T val = + out[nr * BaseType::blockCols() * BaseType::blockRowSize() * + BaseType::blockColSize() + + nc * BaseType::blockRowSize() * BaseType::blockColSize() + + r * BaseType::blockColSize() * row_interleave_ + c]; + if (std::is_integral<T>::value) { + // cast to int64 because cout doesn't print int8_t type directly + std::cout << std::setw(5) << static_cast<int64_t>(val) << " "; + } else { + std::cout << std::setw(5) << val << " "; + } } + std::cout << std::endl; } std::cout << std::endl; } - std::cout << std::endl; } } } |