diff options
-rw-r--r-- | README.md | 8 | ||||
-rw-r--r-- | benchmark.cc | 2 | ||||
-rw-r--r-- | example.cc | 12 | ||||
-rw-r--r-- | intgemm.cc | 12 | ||||
-rw-r--r-- | intgemm.h | 4 | ||||
-rw-r--r-- | test.cc | 4 |
6 files changed, 20 insertions, 22 deletions
@@ -28,13 +28,13 @@ Both A and B should be prepared before multiplication. * B is width x B_cols. */ /* Prepare A for multiplication. This might be offline or on the fly. */ -intgemm::Generic_16bit::PrepareA(A, A_prepared, quant_mult, A_rows, width); +intgemm::Int16::PrepareA(A, A_prepared, quant_mult, A_rows, width); /* Prepare B for multiplication. This is typically done offline. */ -intgemm::Generic_16bit::PrepareB(B, B_prepared, quant_mult, width, B_cols); +intgemm::Int16::PrepareB(B, B_prepared, quant_mult, width, B_cols); /* Multiply and produce results in C */ -intgemm::Generic_16bit::Multiply(A_prepared, B_prepared, C, 1.0 / (quant_mult * quant_mult), A_rows, width, B_cols); +intgemm::Int16::Multiply(A_prepared, B_prepared, C, 1.0 / (quant_mult * quant_mult), A_rows, width, B_cols); ``` -For 8-bit, use `Generic_8bit` instead of `Generic_16bit`. +For 8-bit, use `Int8` instead of `Int16`. When represented as floats, all of A, B, and C are in row-major format. diff --git a/benchmark.cc b/benchmark.cc index 014cb4d..be07e9a 100644 --- a/benchmark.cc +++ b/benchmark.cc @@ -66,13 +66,11 @@ void Time(int A_rows, int width, int B_cols, int repeat = 20) { #ifdef __AVX512BW__ Run<AVX512_8bit>(m, repeat); #endif - Run<Generic_8bit>(m, repeat); Run<SSE2_16bit>(m, repeat); Run<AVX2_16bit>(m, repeat); #ifdef __AVX512BW__ Run<AVX512_16bit>(m, repeat); #endif - Run<Generic_16bit>(m, repeat); } } // namespace intgemm @@ -41,14 +41,14 @@ int main() { AlignedVector<int16_t> A_prepared(A_rows * width); AlignedVector<int16_t> B_prepared(width * B_cols); // Quantize A. - intgemm::Generic_16bit::PrepareA(A.get(), A_prepared.get(), quant_mult, A_rows, width); + intgemm::Int16::PrepareA(A.get(), A_prepared.get(), quant_mult, A_rows, width); // Quantize and reshape B. // Typically you will do this once when parameters are loaded, not every time. 
- intgemm::Generic_16bit::PrepareB(B.get(), B_prepared.get(), quant_mult, width, B_cols); + intgemm::Int16::PrepareB(B.get(), B_prepared.get(), quant_mult, width, B_cols); AlignedVector<float> C(A_rows * B_cols); // Do the actual multiply. - intgemm::Generic_16bit::Multiply(A_prepared.get(), B_prepared.get(), C.get(), 1.0 / (quant_mult * quant_mult), A_rows, width, B_cols); + intgemm::Int16::Multiply(A_prepared.get(), B_prepared.get(), C.get(), 1.0 / (quant_mult * quant_mult), A_rows, width, B_cols); // Sanity check. C will be row major. assert(fabs(C[0] - top_left_reference) < 0.05); } @@ -60,14 +60,14 @@ int main() { AlignedVector<int8_t> A_prepared(A_rows * width); AlignedVector<int8_t> B_prepared(width * B_cols); // Quantize A. - intgemm::Generic_8bit::PrepareA(A.get(), A_prepared.get(), quant_mult, A_rows, width); + intgemm::Int8::PrepareA(A.get(), A_prepared.get(), quant_mult, A_rows, width); // Quantize and reshape B. // Typically you will do this once when parameters are loaded, not every time. - intgemm::Generic_8bit::PrepareB(B.get(), B_prepared.get(), quant_mult, width, B_cols); + intgemm::Int8::PrepareB(B.get(), B_prepared.get(), quant_mult, width, B_cols); AlignedVector<float> C(A_rows * B_cols); // Do the actual multiply. - intgemm::Generic_8bit::Multiply(A_prepared.get(), B_prepared.get(), C.get(), 1.0 / (quant_mult * quant_mult), A_rows, width, B_cols); + intgemm::Int8::Multiply(A_prepared.get(), B_prepared.get(), C.get(), 1.0 / (quant_mult * quant_mult), A_rows, width, B_cols); // Sanity check. C will be row major. 
assert(fabs(C[0] - top_left_reference) < 0.05); } @@ -69,12 +69,12 @@ template <class T> T ChooseCPU(T avx512, T avx2, T ssse3, T sse2, T unsupported) } // namespace -void (*Generic_16bit::Quantize)(const float *input, int16_t *output, float quant_mult, int size) = ChooseCPU(AVX512_16bit::Quantize, AVX2_16bit::Quantize, SSE2_16bit::Quantize, SSE2_16bit::Quantize, Unsupported_16bit::Quantize); -void (*Generic_16bit::PrepareB)(const float *input, int16_t *output, float quant_mult, int rows, int cols) = ChooseCPU(AVX512_16bit::PrepareB, AVX2_16bit::PrepareB, SSE2_16bit::PrepareB, SSE2_16bit::PrepareB, Unsupported_16bit::PrepareB); -void (*Generic_16bit::Multiply)(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) = ChooseCPU(AVX512_16bit::Multiply, AVX2_16bit::Multiply, SSE2_16bit::Multiply, SSE2_16bit::Multiply, Unsupported_16bit::Multiply); +void (*Int16::Quantize)(const float *input, int16_t *output, float quant_mult, int size) = ChooseCPU(AVX512_16bit::Quantize, AVX2_16bit::Quantize, SSE2_16bit::Quantize, SSE2_16bit::Quantize, Unsupported_16bit::Quantize); +void (*Int16::PrepareB)(const float *input, int16_t *output, float quant_mult, int rows, int cols) = ChooseCPU(AVX512_16bit::PrepareB, AVX2_16bit::PrepareB, SSE2_16bit::PrepareB, SSE2_16bit::PrepareB, Unsupported_16bit::PrepareB); +void (*Int16::Multiply)(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) = ChooseCPU(AVX512_16bit::Multiply, AVX2_16bit::Multiply, SSE2_16bit::Multiply, SSE2_16bit::Multiply, Unsupported_16bit::Multiply); -void (*Generic_8bit::Quantize)(const float *input, int8_t *output, float quant_mult, int size) = ChooseCPU(AVX512_8bit::Quantize, AVX2_8bit::Quantize, SSSE3_8bit::Quantize, Unsupported_8bit::Quantize, Unsupported_8bit::Quantize); -void (*Generic_8bit::PrepareB)(const float *input, int8_t *output, float quant_mult, int rows, int cols) = ChooseCPU(AVX512_8bit::PrepareB, 
AVX2_8bit::PrepareB, SSSE3_8bit::PrepareB, Unsupported_8bit::PrepareB, Unsupported_8bit::PrepareB); -void (*Generic_8bit::Multiply)(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) = ChooseCPU(AVX512_8bit::Multiply, AVX2_8bit::Multiply, SSSE3_8bit::Multiply, Unsupported_8bit::Multiply, Unsupported_8bit::Multiply); +void (*Int8::Quantize)(const float *input, int8_t *output, float quant_mult, int size) = ChooseCPU(AVX512_8bit::Quantize, AVX2_8bit::Quantize, SSSE3_8bit::Quantize, Unsupported_8bit::Quantize, Unsupported_8bit::Quantize); +void (*Int8::PrepareB)(const float *input, int8_t *output, float quant_mult, int rows, int cols) = ChooseCPU(AVX512_8bit::PrepareB, AVX2_8bit::PrepareB, SSSE3_8bit::PrepareB, Unsupported_8bit::PrepareB, Unsupported_8bit::PrepareB); +void (*Int8::Multiply)(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) = ChooseCPU(AVX512_8bit::Multiply, AVX2_8bit::Multiply, SSSE3_8bit::Multiply, Unsupported_8bit::Multiply, Unsupported_8bit::Multiply); } // namespace intgemm @@ -58,7 +58,7 @@ class UnsupportedCPU : public std::exception { }; /* 16-bit matrix multiplication. */ -struct Generic_16bit { +struct Int16 { typedef int16_t Integer; // A's size must be a multiple of 1x32. @@ -90,7 +90,7 @@ struct Generic_16bit { }; /* 8-bit matrix multiplication */ -struct Generic_8bit { +struct Int8 { typedef int8_t Integer; // A's size must be a multiple of 1x64. 
@@ -220,12 +220,12 @@ template <class Routine> void TestMultiply(int A_rows, int width, int B_cols) { } void TestBoth(int A_rows, int width, int B_cols) { - if (Generic_16bit::Quantize == AVX512_16bit::Quantize) { + if (Int16::Quantize == AVX512_16bit::Quantize) { TestMultiply<AVX512_16bit>(A_rows, width, B_cols); } TestMultiply<AVX2_16bit>(A_rows, width, B_cols); TestMultiply<SSE2_16bit>(A_rows, width, B_cols); - if (Generic_16bit::Quantize == AVX512_16bit::Quantize) { + if (Int16::Quantize == AVX512_16bit::Quantize) { TestMultiply<AVX512_8bit>(A_rows, width, B_cols); } TestMultiply<AVX2_8bit>(A_rows, width, B_cols); |