From 5e71d2c304663f3b4e50cee723b8e98a867d11ca Mon Sep 17 00:00:00 2001 From: Daya Khudia Date: Wed, 12 Jun 2019 11:26:39 -0700 Subject: Print packed matrix for each group as well Summary: same as title. We were only printing packed matrix for group 0 Reviewed By: jianyuh Differential Revision: D15775235 fbshipit-source-id: 747550c9ae229a2eeb912409897c1331ada81e2b --- src/PackBMatrix.cc | 52 +++++++++++++++++++++++++++++----------------------- 1 file changed, 29 insertions(+), 23 deletions(-) (limited to 'src') diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc index b6d06ca..76198ca 100644 --- a/src/PackBMatrix.cc +++ b/src/PackBMatrix.cc @@ -382,33 +382,39 @@ void PackBMatrix::printPackedMatrix(std::string name) { << "[" << BaseType::blockRowSize() << ", " << BaseType::blockColSize() << "]" << std::endl; - T* out = BaseType::getBuf(); - - for (auto nr = 0; nr < BaseType::blockRows(); ++nr) { - auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow() - : BaseType::blockRowSize(); - for (auto nc = 0; nc < BaseType::blockCols(); ++nc) { - std::cout << "block:" << nr << ", " << nc << std::endl; - auto cols = (nc == BaseType::blockCols() - 1) ? BaseType::lastBcol() - : BaseType::blockColSize(); - for (auto r = 0; r < (rows + row_interleave_ - 1) / row_interleave_; - ++r) { - for (auto c = 0; c < cols * row_interleave_; ++c) { - T val = - out[nr * BaseType::blockCols() * BaseType::blockRowSize() * - BaseType::blockColSize() + - nc * BaseType::blockRowSize() * BaseType::blockColSize() + - r * BaseType::blockColSize() * row_interleave_ + c]; - if (std::is_integral::value) { - // cast to int64 because cout doesn't print int8_t type directly - std::cout << std::setw(5) << static_cast(val) << " "; - } else { - std::cout << std::setw(5) << val << " "; + for (int g = 0; g < BaseType::numGroups(); ++g) { + T* out = BaseType::getBuf() + + g * + BaseType::packedBufferSize( + BaseType::numPackedRows(), BaseType::numPackedCols()); + std::cout << "group: " << g << std::endl; + for (auto nr = 0; nr < BaseType::blockRows(); ++nr) { + auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow() + : BaseType::blockRowSize(); + for (auto nc = 0; nc < BaseType::blockCols(); ++nc) { + std::cout << "block:" << nr << ", " << nc << std::endl; + auto cols = (nc == BaseType::blockCols() - 1) + ? BaseType::lastBcol() + : BaseType::blockColSize(); + for (auto r = 0; r < (rows + row_interleave_ - 1) / row_interleave_; + ++r) { + for (auto c = 0; c < cols * row_interleave_; ++c) { + T val = + out[nr * BaseType::blockCols() * BaseType::blockRowSize() * + BaseType::blockColSize() + + nc * BaseType::blockRowSize() * BaseType::blockColSize() + + r * BaseType::blockColSize() * row_interleave_ + c]; + if (std::is_integral::value) { + // cast to int64 because cout doesn't print int8_t type directly + std::cout << std::setw(5) << static_cast(val) << " "; + } else { + std::cout << std::setw(5) << val << " "; + } } + std::cout << std::endl; } std::cout << std::endl; } - std::cout << std::endl; } } } -- cgit v1.2.3