Diffstat (limited to 'include/fbgemm/FbgemmFP16.h')
-rw-r--r--  include/fbgemm/FbgemmFP16.h | 129
1 file changed, 77 insertions, 52 deletions
diff --git a/include/fbgemm/FbgemmFP16.h b/include/fbgemm/FbgemmFP16.h
index 55718d4..0428d93 100644
--- a/include/fbgemm/FbgemmFP16.h
+++ b/include/fbgemm/FbgemmFP16.h
@@ -10,6 +10,7 @@
 // upgraded to match with new fbgemm interface.
 
 #include <cassert>
+#include <cstdlib>
 #include <memory>
 #include <vector>
 
@@ -22,7 +23,7 @@ namespace fbgemm2 {
 /// row-major format into
 /// internal packed blocked-row major format
 class PackedGemmMatrixFP16 {
-public:
+ public:
   // takes smat input mamtrix in row-major format;
   // and packs it into gemm-friendly blocked format;
   // allocate space and sets up all the internal variables;
@@ -32,30 +33,31 @@ public:
   // before flushing into fp32
   // the smaller the brow_, the higher overhead
   // of flushing is
-  PackedGemmMatrixFP16(const matrix_op_t trans, const int nrow,
-                       const int ncol, const float alpha,
-                       const float *smat,
-                       const int brow = 512)
+  PackedGemmMatrixFP16(
+      const matrix_op_t trans,
+      const int nrow,
+      const int ncol,
+      const float alpha,
+      const float* smat,
+      const int brow = 512)
       : nrow_(nrow), ncol_(ncol), brow_(brow) {
-
     bcol_ = 8 * 1; // hardwired
 
     // set up internal packing parameters
     nbrow_ = ((numRows() % blockRowSize()) == 0)
-                 ? (numRows() / blockRowSize())
-                 : ((numRows() + blockRowSize()) / blockRowSize());
+        ? (numRows() / blockRowSize())
+        : ((numRows() + blockRowSize()) / blockRowSize());
     last_brow_ = ((nrow % blockRowSize()) == 0) ? blockRowSize()
-                                                : (nrow % blockRowSize());
+                                                 : (nrow % blockRowSize());
     nbcol_ = ((numCols() % blockColSize()) == 0)
-                 ? (numCols() / blockColSize())
-                 : ((numCols() + blockColSize()) / blockColSize());
+        ? (numCols() / blockColSize())
+        : ((numCols() + blockColSize()) / blockColSize());
 
     if (numCols() != blockColSize() * nbcol_) {
 #ifdef VLOG
-      VLOG(0)
-          << "Packer warning: ncol(" << numCols()
-          << ") is not a multiple of internal block size (" << blockColSize()
-          << ")";
+      VLOG(0) << "Packer warning: ncol(" << numCols()
+              << ") is not a multiple of internal block size ("
+              << blockColSize() << ")";
       VLOG(0)
           << "lefover is currently done via MKL: hence overhead will inccur";
 #endif
@@ -64,7 +66,9 @@ public:
     // allocate and initialize packed memory
     const int padding = 1024; // required by sw pipelined kernels
     size_ = (blockRowSize() * nbrow_) * (blockColSize() * nbcol_);
-    pmat_ = (float16 *)aligned_alloc(64, matSize() * sizeof(float16) + padding);
+    // pmat_ = (float16 *)aligned_alloc(64, matSize() * sizeof(float16) +
+    // padding);
+    posix_memalign((void**)&pmat_, 64, matSize() * sizeof(float16) + padding);
     for (auto i = 0; i < matSize(); i++) {
       pmat_[i] = tconv(0.f, pmat_[i]);
     }
@@ -77,7 +81,7 @@ public:
     // copy source matrix into packed matrix
     this->packFromSrc(trans, alpha, smat);
   }
 
   ~PackedGemmMatrixFP16() {
     free(pmat_);
   }
 
-// protected:
+  // protected:
   // blocked row-major format address arithmetic
   uint64_t addr(const int r_, const int c_) const {
     uint64_t r = (uint64_t)r_;
@@ -87,10 +91,9 @@ public:
              brow_offset =
                  (block_row_id * nbcol_) * (blockRowSize() * blockColSize());
     uint64_t block_col_id = c / blockColSize(),
-             bcol_offset =
-                 block_col_id * ((block_row_id != nbrow_ - 1)
-                                     ? (blockRowSize() * blockColSize())
-                                     : (last_brow_ * blockColSize()));
+             bcol_offset = block_col_id *
+                 ((block_row_id != nbrow_ - 1) ? (blockRowSize() * blockColSize())
+                                               : (last_brow_ * blockColSize()));
     uint64_t block_offset = brow_offset + bcol_offset;
     uint64_t inblock_offset =
         r % blockRowSize() * blockColSize() + c % blockColSize();
@@ -100,22 +103,22 @@ public:
     return index;
   }
 
-  void packFromSrc(const matrix_op_t trans, const float alpha,
-                   const float *smat) {
+  void
+  packFromSrc(const matrix_op_t trans, const float alpha, const float* smat) {
     bool tr = (trans == matrix_op_t::Transpose);
     // pack
     for (int i = 0; i < numRows(); i++) {
       for (int j = 0; j < numCols(); j++) {
         pmat_[addr(i, j)] = tconv(
-            alpha * (
-                (tr == false)
-                    ? smat[i * numCols() + j] : smat[i + numRows() * j]),
+            alpha *
+                ((tr == false) ? smat[i * numCols() + j]
+                               : smat[i + numRows() * j]),
             pmat_[addr(i, j)]);
       }
     }
   }
 
-  const float16 &operator()(const int r, const int c) const {
+  const float16& operator()(const int r, const int c) const {
     uint64_t a = addr(r, c);
     assert(r < numRows());
     assert(c < numCols());
@@ -123,38 +126,60 @@ public:
     return pmat_[a];
   }
 
-  int matSize() const { return size_; }
-  int numRows() const { return nrow_; }
-  int numCols() const { return ncol_; }
-  inline int blockRowSize() const { return brow_; }
-  inline int blockColSize() const { return bcol_; }
+  int matSize() const {
+    return size_;
+  }
+  int numRows() const {
+    return nrow_;
+  }
+  int numCols() const {
+    return ncol_;
+  }
+  inline int blockRowSize() const {
+    return brow_;
+  }
+  inline int blockColSize() const {
+    return bcol_;
+  }
 
   int nrow_, ncol_;
   int brow_, last_brow_, bcol_;
   int nbrow_, nbcol_;
   uint64_t size_;
-  float16 *pmat_;
-
-  friend void cblas_gemm_compute(const matrix_op_t transa, const int m,
-                                 const float *A,
-                                 const PackedGemmMatrixFP16 &Bp,
-                                 const float beta, float *C);
-  friend void cblas_gemm_compute(const matrix_op_t transa, const int m,
-                                 const float *A,
-                                 const PackedGemmMatrixFP16 &Bp,
-                                 const float beta, float *C);
+  float16* pmat_;
+
+  friend void cblas_gemm_compute(
+      const matrix_op_t transa,
+      const int m,
+      const float* A,
+      const PackedGemmMatrixFP16& Bp,
+      const float beta,
+      float* C);
+  friend void cblas_gemm_compute(
+      const matrix_op_t transa,
+      const int m,
+      const float* A,
+      const PackedGemmMatrixFP16& Bp,
+      const float beta,
+      float* C);
 };
 
 /**
  * restrictions: transa == CblasNoTrans
  */
-extern void cblas_gemm_compute(const matrix_op_t transa, const int m,
-                               const float *A,
-                               const PackedGemmMatrixFP16 &Bp,
-                               const float beta, float *C);
-extern void cblas_gemm_compute(const matrix_op_t transa, const int m,
-                               const float *A,
-                               const PackedGemmMatrixFP16 &Bp,
-                               const float beta, float *C);
-
-}; // namespace fbgemm
+extern void cblas_gemm_compute(
+    const matrix_op_t transa,
+    const int m,
+    const float* A,
+    const PackedGemmMatrixFP16& Bp,
+    const float beta,
+    float* C);
+extern void cblas_gemm_compute(
+    const matrix_op_t transa,
+    const int m,
+    const float* A,
+    const PackedGemmMatrixFP16& Bp,
+    const float beta,
+    float* C);
+
+}; // namespace fbgemm2
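
Beyond whitespace, the one functional change in this diff is the constructor's allocation: the aligned_alloc call is commented out in favor of posix_memalign, which is also why <cstdlib> is now included. A minimal sketch of an equivalent allocation follows, assuming a POSIX platform; the helper name alloc_aligned_64 is illustrative, not part of the header, and the portability rationale is an inference, not stated in the commit:

#include <cstdlib>

// Hypothetical helper mirroring the patched constructor's allocation:
// 64-byte-aligned storage, released with free() as in ~PackedGemmMatrixFP16().
static void* alloc_aligned_64(std::size_t bytes) {
  void* p = nullptr;
  // posix_memalign (POSIX.1-2001) does not carry aligned_alloc's C11
  // requirement that `bytes` be a multiple of the alignment.
  if (posix_memalign(&p, 64, bytes) != 0) {
    return nullptr; // allocation failed
  }
  return p;
}

Note that the patched constructor discards posix_memalign's return value, so an allocation failure would only surface later, when pmat_ is first touched.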
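The addr() routine reindented above is the heart of the packed layout: the matrix is tiled into blockRowSize() x blockColSize() blocks stored block-row-major, with the final row of blocks shortened to last_brow_ rows. Below is a standalone restatement of the same index arithmetic with the header's defaults (brow = 512, bcol = 8) baked in; blocked_addr is an illustrative name, not part of the header:

#include <cstdint>

// Same arithmetic as PackedGemmMatrixFP16::addr(), defaults hard-coded.
uint64_t blocked_addr(
    uint64_t r,
    uint64_t c,
    uint64_t nbrow,       // number of block rows
    uint64_t nbcol,       // number of block columns
    uint64_t last_brow) { // rows in the final block row
  const uint64_t brow = 512, bcol = 8; // header defaults
  uint64_t block_row_id = r / brow;
  // a full block row holds nbcol blocks of brow * bcol elements
  uint64_t brow_offset = (block_row_id * nbcol) * (brow * bcol);
  uint64_t block_col_id = c / bcol;
  // blocks in the last block row are only last_brow rows tall
  uint64_t bcol_offset = block_col_id *
      ((block_row_id != nbrow - 1) ? (brow * bcol) : (last_brow * bcol));
  // row-major position inside a block
  uint64_t inblock_offset = (r % brow) * bcol + c % bcol;
  return brow_offset + bcol_offset + inblock_offset;
}

Keeping each block contiguous lets a kernel stream through one brow x 8 panel of the packed matrix at a time, which is presumably why an ncol that is not a multiple of 8 triggers the MKL-leftover warning in the constructor.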
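Taken together, the class and the free function declare a pack-once, run-many interface: B is fp16-converted and re-blocked at construction, then cblas_gemm_compute repeatedly multiplies fp32 activations against it. A minimal usage sketch against the declarations above; the sizes are made up, and the enumerator matrix_op_t::NoTranspose is an assumption (only matrix_op_t::Transpose is visible in this diff):

#include "fbgemm/FbgemmFP16.h"

#include <vector>

using namespace fbgemm2;

int main() {
  const int m = 4, k = 256, n = 512; // illustrative dimensions
  std::vector<float> A(m * k, 1.f);  // fp32 activations, row-major
  std::vector<float> B(k * n, 1.f);  // fp32 weights, row-major
  std::vector<float> C(m * n, 0.f);  // fp32 output

  // pack B once: converts to fp16 and re-blocks it (alpha = 1.0)
  PackedGemmMatrixFP16 Bp(matrix_op_t::NoTranspose, k, n, 1.f, B.data());

  // C = A * Bp; per the header comment, transa must be no-transpose
  cblas_gemm_compute(matrix_op_t::NoTranspose, m, A.data(), Bp, 0.f, C.data());
  return 0;
}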