diff options
author | Kenneth Heafield <github@kheafield.com> | 2018-06-17 01:27:20 +0300 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2018-06-17 01:27:20 +0300 |
commit | dccc15444c9a0a9adfb2216d9a28635f0b5f0610 (patch) | |
tree | 964ee6db5e1629436a2c8a0f7b22395c31287e26 /avx2_gemm.h | |
parent | c90b2d8a81978715a30c90636f2a4fc6bcd37aaf (diff) |
Tested reshaping
Diffstat (limited to 'avx2_gemm.h')
-rw-r--r-- | avx2_gemm.h | 49 |
1 file changed, 36 insertions, 13 deletions
```diff
diff --git a/avx2_gemm.h b/avx2_gemm.h
index 699383f..337ee44 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -1,23 +1,46 @@
 #pragma once
-#include <immintrin.h>
-#include <cstddef>
+#include <cstdint>
 
 namespace intgemm {
 #ifdef __AVX2__
-namespace AVX2 {
 
-void Quantize16(const float *input, int16_t *output, float quant_mult, std::size_t size);
-void Quantize8(const float *input, int8_t *output, float quant_mult, std::size_t size);
+struct AVX2_16bit {
+  typedef int16_t Integer;
 
-// Multiply C = unquant_mult * A * B^T. A is normally activations and B is normally a parameter matrix.
-// Values of A and B should come from the corresponding quantizer.
-// A, B, and C must be 32-byte alined.
-void MatrixMult16(const __m256i *A, const __m256i *B, float *C, float unquant_mult, int num_A_rows, int num_B_rows, int width);
-void MatrixMult8(const __m256i *A, const __m256i *B, float *C, float unquant_mult, int num_A_rows, int num_B_rows, int width);
+  // Currently A is prepared by quantization but this could theoretically change.
+  static void PrepareA(const float *input, int16_t *output, float quant_mult, int rows, int cols) {
+    Quantize(input, output, quant_mult, rows * cols);
+  }
 
-void MatrixMult8Contrast(const __m256i *A, const __m256i *B, float *C, float unquant_mult, int num_A_rows, int num_B_rows, int width);
-void MatrixMult8ASM(const __m256i *A, const __m256i *B, float *C, float unquant_mult, int num_A_rows, int num_B_rows, int width);
+  static void Quantize(const float *input, int16_t *output, float quant_mult, int size);
+
+  // Tile size for B; B must be a multiple of this block size.
+  static const int kBTileRow = 16;
+  static const int kBTileCol = 8;
+
+  static void PrepareB(const float *input, int16_t *output, float quant_mult, int rows, int cols);
+
+  static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
+};
+
+struct AVX2_8bit {
+  typedef int8_t Integer;
+
+  // Currently A is prepared by quantization but this could theoretically change.
+  static void PrepareA(const float *input, int8_t *output, float quant_mult, int rows, int cols) {
+    Quantize(input, output, quant_mult, rows * cols);
+  }
+
+  static void Quantize(const float *input, int8_t *output, float quant_mult, int size);
+
+  // Tile size for B; B must be a multiple of this block size.
+  static const int kBTileRow = 32;
+  static const int kBTileCol = 8;
+
+  static void PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols);
+
+  static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
+};
 
-} // namespace AVX2
 #endif // __AVX2__
 } // namespace intgemm
```