From 0ae900a954ee19ebb8d33ba57493e4a14249c719 Mon Sep 17 00:00:00 2001
From: Kenneth Heafield
Date: Sat, 23 Jun 2018 20:27:52 +0100
Subject: Name Generic_16bit to Int16, Generic_8bit to Int8

---
 README.md    |  8 ++++----
 benchmark.cc |  2 --
 example.cc   | 12 ++++++------
 intgemm.cc   | 12 ++++++------
 intgemm.h    |  4 ++--
 test.cc      |  4 ++--
 6 files changed, 20 insertions(+), 22 deletions(-)

diff --git a/README.md b/README.md
index 46a9183..33728a5 100644
--- a/README.md
+++ b/README.md
@@ -28,13 +28,13 @@ Both A and B should be prepared before multiplication.
  * B is width x B_cols. */
 
 /* Prepare A for multiplication. This might be offline or on the fly. */
-intgemm::Generic_16bit::PrepareA(A, A_prepared, quant_mult, A_rows, width);
+intgemm::Int16::PrepareA(A, A_prepared, quant_mult, A_rows, width);
 /* Prepare B for multiplication. This is typically done offline. */
-intgemm::Generic_16bit::PrepareB(B, B_prepared, quant_mult, width, B_cols);
+intgemm::Int16::PrepareB(B, B_prepared, quant_mult, width, B_cols);
 /* Multiply and produce results in C */
-intgemm::Generic_16bit::Multiply(A_prepared, B_prepared, C, 1.0 / (quant_mult * quant_mult), A_rows, width, B_cols);
+intgemm::Int16::Multiply(A_prepared, B_prepared, C, 1.0 / (quant_mult * quant_mult), A_rows, width, B_cols);
 ```
-For 8-bit, use `Generic_8bit` instead of `Generic_16bit`.
+For 8-bit, use `Int8` instead of `Int16`.
 
 When repesented as floats, all of A, B, and C are in row-major format.
 
diff --git a/benchmark.cc b/benchmark.cc
index 014cb4d..be07e9a 100644
--- a/benchmark.cc
+++ b/benchmark.cc
@@ -66,13 +66,11 @@ void Time(int A_rows, int width, int B_cols, int repeat = 20) {
 #ifdef __AVX512BW__
   Run(m, repeat);
 #endif
-  Run(m, repeat);
   Run(m, repeat);
   Run(m, repeat);
 #ifdef __AVX512BW__
   Run(m, repeat);
 #endif
-  Run(m, repeat);
 }
 
 } // namespace intgemm
diff --git a/example.cc b/example.cc
index c67e48b..0072596 100644
--- a/example.cc
+++ b/example.cc
@@ -41,14 +41,14 @@ int main() {
   AlignedVector<int16_t> A_prepared(A_rows * width);
   AlignedVector<int16_t> B_prepared(width * B_cols);
   // Quantize A.
-  intgemm::Generic_16bit::PrepareA(A.get(), A_prepared.get(), quant_mult, A_rows, width);
+  intgemm::Int16::PrepareA(A.get(), A_prepared.get(), quant_mult, A_rows, width);
   // Quantize and reshape B.
   // Typically you will do this once when parameters are loaded, not every time.
-  intgemm::Generic_16bit::PrepareB(B.get(), B_prepared.get(), quant_mult, width, B_cols);
+  intgemm::Int16::PrepareB(B.get(), B_prepared.get(), quant_mult, width, B_cols);
   AlignedVector<float> C(A_rows * B_cols);
   // Do the actual multiply.
-  intgemm::Generic_16bit::Multiply(A_prepared.get(), B_prepared.get(), C.get(), 1.0 / (quant_mult * quant_mult), A_rows, width, B_cols);
+  intgemm::Int16::Multiply(A_prepared.get(), B_prepared.get(), C.get(), 1.0 / (quant_mult * quant_mult), A_rows, width, B_cols);
   // Sanity check. C will be row major.
   assert(fabs(C[0] - top_left_reference) < 0.05);
 }
 
@@ -60,14 +60,14 @@ int main() {
   AlignedVector<int8_t> A_prepared(A_rows * width);
   AlignedVector<int8_t> B_prepared(width * B_cols);
   // Quantize A.
-  intgemm::Generic_8bit::PrepareA(A.get(), A_prepared.get(), quant_mult, A_rows, width);
+  intgemm::Int8::PrepareA(A.get(), A_prepared.get(), quant_mult, A_rows, width);
   // Quantize and reshape B.
   // Typically you will do this once when parameters are loaded, not every time.
-  intgemm::Generic_8bit::PrepareB(B.get(), B_prepared.get(), quant_mult, width, B_cols);
+  intgemm::Int8::PrepareB(B.get(), B_prepared.get(), quant_mult, width, B_cols);
   AlignedVector<float> C(A_rows * B_cols);
   // Do the actual multiply.
-  intgemm::Generic_8bit::Multiply(A_prepared.get(), B_prepared.get(), C.get(), 1.0 / (quant_mult * quant_mult), A_rows, width, B_cols);
+  intgemm::Int8::Multiply(A_prepared.get(), B_prepared.get(), C.get(), 1.0 / (quant_mult * quant_mult), A_rows, width, B_cols);
   // Sanity check. C will be row major.
   assert(fabs(C[0] - top_left_reference) < 0.05);
 }
 
diff --git a/intgemm.cc b/intgemm.cc
index bdf2f04..7653a85 100644
--- a/intgemm.cc
+++ b/intgemm.cc
@@ -69,12 +69,12 @@ template <class T> T ChooseCPU(T avx512, T avx2, T ssse3, T sse2, T unsupported)
 
 } // namespace
 
-void (*Generic_16bit::Quantize)(const float *input, int16_t *output, float quant_mult, int size) = ChooseCPU(AVX512_16bit::Quantize, AVX2_16bit::Quantize, SSE2_16bit::Quantize, SSE2_16bit::Quantize, Unsupported_16bit::Quantize);
-void (*Generic_16bit::PrepareB)(const float *input, int16_t *output, float quant_mult, int rows, int cols) = ChooseCPU(AVX512_16bit::PrepareB, AVX2_16bit::PrepareB, SSE2_16bit::PrepareB, SSE2_16bit::PrepareB, Unsupported_16bit::PrepareB);
-void (*Generic_16bit::Multiply)(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) = ChooseCPU(AVX512_16bit::Multiply, AVX2_16bit::Multiply, SSE2_16bit::Multiply, SSE2_16bit::Multiply, Unsupported_16bit::Multiply);
+void (*Int16::Quantize)(const float *input, int16_t *output, float quant_mult, int size) = ChooseCPU(AVX512_16bit::Quantize, AVX2_16bit::Quantize, SSE2_16bit::Quantize, SSE2_16bit::Quantize, Unsupported_16bit::Quantize);
+void (*Int16::PrepareB)(const float *input, int16_t *output, float quant_mult, int rows, int cols) = ChooseCPU(AVX512_16bit::PrepareB, AVX2_16bit::PrepareB, SSE2_16bit::PrepareB, SSE2_16bit::PrepareB, Unsupported_16bit::PrepareB);
+void (*Int16::Multiply)(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) = ChooseCPU(AVX512_16bit::Multiply, AVX2_16bit::Multiply, SSE2_16bit::Multiply, SSE2_16bit::Multiply, Unsupported_16bit::Multiply);
 
-void (*Generic_8bit::Quantize)(const float *input, int8_t *output, float quant_mult, int size) = ChooseCPU(AVX512_8bit::Quantize, AVX2_8bit::Quantize, SSSE3_8bit::Quantize, Unsupported_8bit::Quantize, Unsupported_8bit::Quantize);
-void (*Generic_8bit::PrepareB)(const float *input, int8_t *output, float quant_mult, int rows, int cols) = ChooseCPU(AVX512_8bit::PrepareB, AVX2_8bit::PrepareB, SSSE3_8bit::PrepareB, Unsupported_8bit::PrepareB, Unsupported_8bit::PrepareB);
-void (*Generic_8bit::Multiply)(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) = ChooseCPU(AVX512_8bit::Multiply, AVX2_8bit::Multiply, SSSE3_8bit::Multiply, Unsupported_8bit::Multiply, Unsupported_8bit::Multiply);
+void (*Int8::Quantize)(const float *input, int8_t *output, float quant_mult, int size) = ChooseCPU(AVX512_8bit::Quantize, AVX2_8bit::Quantize, SSSE3_8bit::Quantize, Unsupported_8bit::Quantize, Unsupported_8bit::Quantize);
+void (*Int8::PrepareB)(const float *input, int8_t *output, float quant_mult, int rows, int cols) = ChooseCPU(AVX512_8bit::PrepareB, AVX2_8bit::PrepareB, SSSE3_8bit::PrepareB, Unsupported_8bit::PrepareB, Unsupported_8bit::PrepareB);
+void (*Int8::Multiply)(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) = ChooseCPU(AVX512_8bit::Multiply, AVX2_8bit::Multiply, SSSE3_8bit::Multiply, Unsupported_8bit::Multiply, Unsupported_8bit::Multiply);
 
 } // namespace intgemm
diff --git a/intgemm.h b/intgemm.h
index 016acd6..ea651b9 100644
--- a/intgemm.h
+++ b/intgemm.h
@@ -58,7 +58,7 @@ class UnsupportedCPU : public std::exception {
 };
 
 /* 16-bit matrix multiplication. */
-struct Generic_16bit {
+struct Int16 {
   typedef int16_t Integer;
 
   // A's size must be a multiple of 1x32.
@@ -90,7 +90,7 @@ struct Generic_16bit {
 };
 
 /* 8-bit matrix multiplication */
-struct Generic_8bit {
+struct Int8 {
   typedef int8_t Integer;
 
   // A's size must be a multiple of 1x64.
diff --git a/test.cc b/test.cc
index 58e12e0..bd0ccb6 100644
--- a/test.cc
+++ b/test.cc
@@ -220,12 +220,12 @@ template void TestMultiply(int A_rows, int width, int B_cols) {
 }
 
 void TestBoth(int A_rows, int width, int B_cols) {
-  if (Generic_16bit::Quantize == AVX512_16bit::Quantize) {
+  if (Int16::Quantize == AVX512_16bit::Quantize) {
     TestMultiply(A_rows, width, B_cols);
   }
   TestMultiply(A_rows, width, B_cols);
   TestMultiply(A_rows, width, B_cols);
-  if (Generic_16bit::Quantize == AVX512_16bit::Quantize) {
+  if (Int16::Quantize == AVX512_16bit::Quantize) {
     TestMultiply(A_rows, width, B_cols);
   }
   TestMultiply(A_rows, width, B_cols);
-- 
cgit v1.2.3
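A note on the intgemm.cc hunk above: the renamed Int16 and Int8 structs expose plain function pointers that are bound exactly once, through ChooseCPU, to whichever ISA-specific kernel the running CPU supports. The following is a minimal, self-contained sketch of that dispatch pattern, assuming a hypothetical DetectCPU() and stand-in kernels that only report which path was selected; it is an illustration of the idea, not the library's actual implementation.

```
// Sketch of ChooseCPU-style dispatch; DetectCPU() and the kernels are
// illustrative placeholders, not intgemm code.
#include <cstdint>
#include <cstdio>

namespace sketch {

enum class CPUType { AVX512BW, AVX2, SSSE3, SSE2, Unsupported };

// Placeholder: a real implementation would query cpuid.
inline CPUType DetectCPU() { return CPUType::AVX2; }

// Pick one of several same-signature function pointers based on the CPU.
template <class T> T ChooseCPU(T avx512, T avx2, T ssse3, T sse2, T unsupported) {
  switch (DetectCPU()) {
    case CPUType::AVX512BW: return avx512;
    case CPUType::AVX2:     return avx2;
    case CPUType::SSSE3:    return ssse3;
    case CPUType::SSE2:     return sse2;
    default:                return unsupported;
  }
}

// Stand-in kernels with the same signature as Int16::Quantize in the patch.
void QuantizeAVX2(const float *, int16_t *, float, int) { std::puts("AVX2 kernel"); }
void QuantizeSSE2(const float *, int16_t *, float, int) { std::puts("SSE2 kernel"); }

struct Int16 {
  static void (*Quantize)(const float *input, int16_t *output, float quant_mult, int size);
};

// Bound once at static-initialization time, mirroring the intgemm.cc hunk.
void (*Int16::Quantize)(const float *, int16_t *, float, int) =
    ChooseCPU(QuantizeAVX2, QuantizeAVX2, QuantizeSSE2, QuantizeSSE2, QuantizeSSE2);

} // namespace sketch

int main() {
  float input[4] = {0.25f, -0.5f, 0.75f, -1.0f};
  int16_t output[4];
  // Callers name only Int16::Quantize; the ISA choice was made at load time.
  sketch::Int16::Quantize(input, output, 1024.0f, 4);
  return 0;
}
```

Because callers only ever spell Int16::Quantize (or PrepareB, Multiply), this commit changes the public names without touching how the ISA-specific kernel is chosen.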