diff options
author | Kenneth Heafield <github@kheafield.com> | 2018-06-17 14:34:59 +0300 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2018-06-17 14:34:59 +0300 |
commit | fd63cf1a77e192e09e36eac0b8d74969f44a3342 (patch) | |
tree | f56dc08f5804efda6b77cd8f962f17a12c7497a5 /sse2_gemm.h | |
parent | 694c597d9d3572331b7bb16b5ab354653b938792 (diff) |
Started work on SSE2
Diffstat (limited to 'sse2_gemm.h')
-rw-r--r-- | sse2_gemm.h | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/sse2_gemm.h b/sse2_gemm.h new file mode 100644 index 0000000..2c42e74 --- /dev/null +++ b/sse2_gemm.h @@ -0,0 +1,46 @@ +#pragma once +#include <cstdint> + +namespace intgemm { +#ifdef __SSE2__ + +struct SSE2_16bit { + typedef int16_t Integer; + + // Currently A is prepared by quantization but this could theoretically change. + static void PrepareA(const float *input, int16_t *output, float quant_mult, int rows, int cols) { + Quantize(input, output, quant_mult, rows * cols); + } + + static void Quantize(const float *input, int16_t *output, float quant_mult, int size); + + // Tile size for B; B must be a multiple of this block size. + static const int kBTileRow = 8; + static const int kBTileCol = 8; + + static void PrepareB(const float *input, int16_t *output, float quant_mult, int rows, int cols); + + static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols); +}; + +struct SSE2_8bit { + typedef int8_t Integer; + + // Currently A is prepared by quantization but this could theoretically change. + static void PrepareA(const float *input, int8_t *output, float quant_mult, int rows, int cols) { + Quantize(input, output, quant_mult, rows * cols); + } + + static void Quantize(const float *input, int8_t *output, float quant_mult, int size); + + // Tile size for B; B must be a multiple of this block size. + static const int kBTileRow = 16; + static const int kBTileCol = 8; + + static void PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols); + + static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols); +}; + +#endif // __SSE2__ +} // namespace intgemm |