Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDaya Khudia <dskhudia@fb.com>2019-06-12 21:26:39 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-06-12 21:32:28 +0300
commit5e71d2c304663f3b4e50cee723b8e98a867d11ca (patch)
treea3932b85eef3c74abecaf96e2a2d5b32261ef98e /src
parentbf2f45f35cc0d7b6f420894652824a377f764714 (diff)
Print packed matrix for each group as well
Summary: same as title. We were only printing packed matrix for group 0 Reviewed By: jianyuh Differential Revision: D15775235 fbshipit-source-id: 747550c9ae229a2eeb912409897c1331ada81e2b
Diffstat (limited to 'src')
-rw-r--r--src/PackBMatrix.cc52
1 files changed, 29 insertions, 23 deletions
diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc
index b6d06ca..76198ca 100644
--- a/src/PackBMatrix.cc
+++ b/src/PackBMatrix.cc
@@ -382,33 +382,39 @@ void PackBMatrix<T, accT>::printPackedMatrix(std::string name) {
<< "[" << BaseType::blockRowSize() << ", "
<< BaseType::blockColSize() << "]" << std::endl;
- T* out = BaseType::getBuf();
-
- for (auto nr = 0; nr < BaseType::blockRows(); ++nr) {
- auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow()
- : BaseType::blockRowSize();
- for (auto nc = 0; nc < BaseType::blockCols(); ++nc) {
- std::cout << "block:" << nr << ", " << nc << std::endl;
- auto cols = (nc == BaseType::blockCols() - 1) ? BaseType::lastBcol()
- : BaseType::blockColSize();
- for (auto r = 0; r < (rows + row_interleave_ - 1) / row_interleave_;
- ++r) {
- for (auto c = 0; c < cols * row_interleave_; ++c) {
- T val =
- out[nr * BaseType::blockCols() * BaseType::blockRowSize() *
- BaseType::blockColSize() +
- nc * BaseType::blockRowSize() * BaseType::blockColSize() +
- r * BaseType::blockColSize() * row_interleave_ + c];
- if (std::is_integral<T>::value) {
- // cast to int64 because cout doesn't print int8_t type directly
- std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
- } else {
- std::cout << std::setw(5) << val << " ";
+ for (int g = 0; g < BaseType::numGroups(); ++g) {
+ T* out = BaseType::getBuf() +
+ g *
+ BaseType::packedBufferSize(
+ BaseType::numPackedRows(), BaseType::numPackedCols());
+ std::cout << "group: " << g << std::endl;
+ for (auto nr = 0; nr < BaseType::blockRows(); ++nr) {
+ auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow()
+ : BaseType::blockRowSize();
+ for (auto nc = 0; nc < BaseType::blockCols(); ++nc) {
+ std::cout << "block:" << nr << ", " << nc << std::endl;
+ auto cols = (nc == BaseType::blockCols() - 1)
+ ? BaseType::lastBcol()
+ : BaseType::blockColSize();
+ for (auto r = 0; r < (rows + row_interleave_ - 1) / row_interleave_;
+ ++r) {
+ for (auto c = 0; c < cols * row_interleave_; ++c) {
+ T val =
+ out[nr * BaseType::blockCols() * BaseType::blockRowSize() *
+ BaseType::blockColSize() +
+ nc * BaseType::blockRowSize() * BaseType::blockColSize() +
+ r * BaseType::blockColSize() * row_interleave_ + c];
+ if (std::is_integral<T>::value) {
+ // cast to int64 because cout doesn't print int8_t type directly
+ std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
+ } else {
+ std::cout << std::setw(5) << val << " ";
+ }
}
+ std::cout << std::endl;
}
std::cout << std::endl;
}
- std::cout << std::endl;
}
}
}