/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
// NOTE(review): the next four #include directives have no target in this copy
// of the file; the header names appear to have been lost (angle-bracketed text
// is missing throughout). Restore them from the authoritative source. The code
// below uses std::runtime_error, assert and cpuinfo_initialize(), so headers
// providing those are presumably among the missing ones -- TODO confirm.
#include
#include
#include
#include
#include "fbgemm/ConvUtils.h"
#include "fbgemm/Fbgemm.h"

namespace fbgemm {

// Constructor: records the caller-supplied buffer and matrix geometry.
//
// rows / cols are stored in nrows_ / ncols_, buf in buf_, and groups in G_.
// bufAllocatedHere_ is set to false because the buffer pointer comes from the
// caller, and params is kept in blocking_params for later use.
//
// Throws std::runtime_error if cpuinfo_initialize() reports failure.
//
// NOTE(review): the template parameter list after `template` and the template
// arguments after `PackMatrix` are missing in this copy. The body refers to
// `inpType`, and the explicit instantiations at the bottom of the file pass
// three arguments, so this is presumably a three-parameter class template --
// TODO confirm against the class declaration.
template
PackMatrix::PackMatrix(
    int32_t rows,
    int32_t cols,
    inpType* buf,
    int groups,
    const BlockingFactors* params)
    : buf_(buf), nrows_(rows), ncols_(cols), G_(groups) {
  bufAllocatedHere_ = false;
  blocking_params = params;
  if (!cpuinfo_initialize()) {
    throw std::runtime_error("Failed to initialize cpuinfo!");
  }
}

// Computes the size of the packed buffer for a rows x cols matrix, as a count
// derived from the MCB / KCB / NCB blocking factors.
//
// Blocking factors:
//   - If params is non-null, MCB, NCB and KCB are read from it.
//   - Otherwise they come from PackingTraits, chosen by whether
//     fbgemmHasAvx512Support() or fbgemmHasAvx2Support() returns true.
//
// Result (only when one of the two support checks succeeds):
//   - If isA() is true: MCB * KCB, independent of rows and cols.
//   - Otherwise: rows rounded up to a multiple of KCB, multiplied by cols
//     rounded up to a multiple of NCB.
//
// Returns -1 when neither support check succeeds; the assert on that path only
// fires in builds where assertions are enabled.
//
// NOTE(review): isA() is not defined in this file; presumably it identifies
// the packer for the A operand -- TODO confirm.
// NOTE(review): the template parameter list and the template arguments of
// `PackMatrix` and `PackingTraits` are missing in this copy, which is why the
// AVX-512 and AVX2 PackingTraits branches read identically here. They
// presumably select different trait specializations in the original -- TODO
// confirm.
// NOTE(review): the product is computed in `int`; very large rows/cols could
// overflow. Not changed here.
template
int PackMatrix::packedBufferSize(
    int rows,
    int cols,
    const BlockingFactors* params) {
  int MCB, KCB, NCB;
  if (params) {
    // Caller-provided blocking factors take precedence over the defaults.
    MCB = params->MCB;
    NCB = params->NCB;
    KCB = params->KCB;
  } else {
    if (fbgemmHasAvx512Support()) {
      MCB = PackingTraits::MCB;
      NCB = PackingTraits::NCB;
      KCB = PackingTraits::KCB;
    } else if (fbgemmHasAvx2Support()) {
      MCB = PackingTraits::MCB;
      NCB = PackingTraits::NCB;
      KCB = PackingTraits::KCB;
    } else {
      // TODO: Have default slower path
      assert(0 && "unsupported architecure");
      return -1;
    }
  }

  // The AVX-512 and AVX2 branches below perform the same computation on the
  // blocking factors selected above.
  if (fbgemmHasAvx512Support()) {
    if (isA()) {
      return MCB * KCB;
    } else {
      int rowBlock = KCB;
      int colBlock = NCB;
      // Round rows up to a multiple of rowBlock and cols up to a multiple of
      // colBlock, then multiply the two padded extents.
      return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
          (((cols + colBlock - 1) / colBlock) * colBlock);
    }
  } else if (fbgemmHasAvx2Support()) {
    if (isA()) {
      return MCB * KCB;
    } else {
      int rowBlock = KCB;
      int colBlock = NCB;
      // Same padded-extent product as in the AVX-512 branch.
      return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
          (((cols + colBlock - 1) / colBlock) * colBlock);
    }
  } else {
    // Reached when params was supplied but neither instruction-set check
    // succeeds.
    // TODO: Have default slower path
    assert(0 && "unsupported architecure");
  }
  return -1;
}

// Explicit instantiations of PackMatrix, so the member definitions above are
// compiled in this translation unit.
//
// NOTE(review): the inner template arguments of the packer types are missing
// in this copy (e.g. `PackMatrix, uint8_t, int32_t>` has lost its first
// argument, and two PackAWithIm2Col lines per group read identically). Restore
// them from the authoritative source before building.

// int32 accumulation
template class PackMatrix, uint8_t, int32_t>;
template class PackMatrix< PackAWithRowOffset, uint8_t, int32_t>;
template class PackMatrix, uint8_t, int32_t>;
template class PackMatrix< PackAWithIm2Col, uint8_t, int32_t>;
template class PackMatrix< PackAWithQuantRowOffset, uint8_t, int32_t>;
template class PackMatrix, int8_t, int32_t>;

// int16 accumulation
template class PackMatrix, uint8_t, int16_t>;
template class PackMatrix< PackAWithIm2Col, uint8_t, int16_t>;
template class PackMatrix< PackAWithRowOffset, uint8_t, int16_t>;
template class PackMatrix, uint8_t, int16_t>;
template class PackMatrix, int8_t, int16_t>;

} // namespace fbgemm