github.com/marian-nmt/intgemm.git
author     Kenneth Heafield <github@kheafield.com>  2019-02-25 18:11:26 +0300
committer  Kenneth Heafield <github@kheafield.com>  2019-02-25 18:11:26 +0300
commit     43b36d6d39a7b40cefa936c9b0ce4007a2d987e7 (patch)
tree       286c4d4dea8a79498471658fbba227e40e2e5d77
parent     a63b5de54441b0c632866edd525b7c8bbfe2c094 (diff)
Genericize index type
-rw-r--r--  avx2_gemm.cc       20
-rw-r--r--  avx2_gemm.h        30
-rw-r--r--  avx512_gemm.cc     20
-rw-r--r--  avx512_gemm.h      30
-rw-r--r--  benchmark.cc        6
-rw-r--r--  example.cc          7
-rw-r--r--  interleave.h        6
-rw-r--r--  intgemm.cc         34
-rw-r--r--  intgemm.h          38
-rw-r--r--  multiply.h          4
-rw-r--r--  multiply_test.cc   22
-rw-r--r--  sse2_gemm.cc        8
-rw-r--r--  sse2_gemm.h        16
-rw-r--r--  ssse3_gemm.cc      10
-rw-r--r--  ssse3_gemm.h       16
-rw-r--r--  types.h (renamed from cpu_type.h)   2
16 files changed, 137 insertions(+), 132 deletions(-)
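
The change is mechanical but touches every entry point: each dimension, size, and column index that was previously an int (or a std::size_t in SelectColumnsB) becomes a single Index typedef, introduced in types.h (renamed from cpu_type.h). A minimal sketch of the pattern, with signatures taken from the hunks below:

    namespace intgemm {
    // types.h (renamed from cpu_type.h) now owns the index type:
    typedef unsigned int Index;
    } // namespace intgemm

    // Before: static void Quantize(const float *input, int16_t *output, float quant_mult, int size);
    // After:  static void Quantize(const float *input, int16_t *output, float quant_mult, Index size);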
diff --git a/avx2_gemm.cc b/avx2_gemm.cc
index 9fa05c4..455542d 100644
--- a/avx2_gemm.cc
+++ b/avx2_gemm.cc
@@ -29,7 +29,7 @@ class QuantizeTile16 {
return Tile(input, input + 8);
}
- Integer ForReshape(const float *input, int cols) {
+ Integer ForReshape(const float *input, Index cols) {
// 8 rows in the first 128-bit register, 8 in the second register.
return Tile(input, input + 8 * cols);
}
@@ -50,7 +50,7 @@ class QuantizeTile16 {
} // namespace
// Just quantize everything in order.
-void AVX2_16bit::Quantize(const float *input, int16_t *output, float quant_mult, int size) {
+void AVX2_16bit::Quantize(const float *input, int16_t *output, float quant_mult, Index size) {
assert(size % 16 == 0);
assert(reinterpret_cast<uintptr_t>(input) % 32 == 0);
QuantizeTile16 q(quant_mult);
@@ -75,7 +75,7 @@ class QuantizeTile8 {
return Tile(input, input + 8, input + 16, input + 24);
}
- inline __m256i ForReshape(const float *input, int cols) {
+ inline __m256i ForReshape(const float *input, Index cols) {
// Put higher rows in the second half of the register. These will jumble
// around in the same way then conveniently land in the right place.
return Tile(input, input + 2 * cols, input + 16 * cols, input + 18 * cols);
@@ -110,7 +110,7 @@ class QuantizeTile8 {
} // namespace
// Just quantize everything in order.
-void AVX2_8bit::Quantize(const float *input, int8_t *output, float quant_mult, int size) {
+void AVX2_8bit::Quantize(const float *input, int8_t *output, float quant_mult, Index size) {
assert(size % 32 == 0);
assert(reinterpret_cast<uintptr_t>(input) % 32 == 0);
QuantizeTile8 q(quant_mult);
@@ -120,27 +120,27 @@ void AVX2_8bit::Quantize(const float *input, int8_t *output, float quant_mult, i
}
}
-void AVX2_16bit::PrepareB(const float *input, int16_t *output, float quant_mult, int rows, int cols) {
+void AVX2_16bit::PrepareB(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) {
PrepareBFor16(input, output, QuantizeTile16(quant_mult), rows, cols);
}
-void AVX2_16bit::SelectColumnsB(const int16_t *input, int16_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end) {
+void AVX2_16bit::SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows * 2, cols_begin, cols_end);
}
-void AVX2_8bit::PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols) {
+void AVX2_8bit::PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
PrepareBFor8(input, output, QuantizeTile8(quant_mult), rows, cols);
}
-void AVX2_8bit::SelectColumnsB(const int8_t *input, int8_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end) {
+void AVX2_8bit::SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows, cols_begin, cols_end);
}
-void AVX2_16bit::Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) {
+void AVX2_16bit::Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
Multiply16<__m256i, __m256>(A, B, C, unquant_mult, A_rows, width, B_cols);
}
-void AVX2_8bit::Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) {
+void AVX2_8bit::Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
Multiply8_SSE2OrAVX2<Multiply8_AVXAVX2, __m256i, __m256>(A, B, C, unquant_mult, A_rows, width, B_cols);
}
diff --git a/avx2_gemm.h b/avx2_gemm.h
index 5af2a81..4b0b001 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -1,5 +1,5 @@
#pragma once
-#include "cpu_type.h"
+#include "types.h"
#include <cstdint>
#include <stdint.h>
@@ -9,21 +9,21 @@ struct AVX2_16bit {
typedef int16_t Integer;
// Currently A is prepared by quantization but this could theoretically change.
- static inline void PrepareA(const float *input, int16_t *output, float quant_mult, int rows, int cols) {
+ static inline void PrepareA(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) {
Quantize(input, output, quant_mult, rows * cols);
}
- static void Quantize(const float *input, int16_t *output, float quant_mult, int size);
+ static void Quantize(const float *input, int16_t *output, float quant_mult, Index size);
// Tile size for B; B must be a multiple of this block size.
- static const int kBTileRow = 16;
- static const int kBTileCol = 8;
+ static const Index kBTileRow = 16;
+ static const Index kBTileCol = 8;
- static void PrepareB(const float *input, int16_t *output, float quant_mult, int rows, int cols);
+ static void PrepareB(const float *input, int16_t *output, float quant_mult, Index rows, Index cols);
- static void SelectColumnsB(const int16_t *input, int16_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end);
+ static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end);
- static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
+ static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols);
static const char *const kName;
@@ -34,21 +34,21 @@ struct AVX2_8bit {
typedef int8_t Integer;
// Currently A is prepared by quantization but this could theoretically change.
- static inline void PrepareA(const float *input, int8_t *output, float quant_mult, int rows, int cols) {
+ static inline void PrepareA(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
Quantize(input, output, quant_mult, rows * cols);
}
- static void Quantize(const float *input, int8_t *output, float quant_mult, int size);
+ static void Quantize(const float *input, int8_t *output, float quant_mult, Index size);
// Tile size for B; B must be a multiple of this block size.
- static const int kBTileRow = 32;
- static const int kBTileCol = 8;
+ static const Index kBTileRow = 32;
+ static const Index kBTileCol = 8;
- static void PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols);
+ static void PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols);
- static void SelectColumnsB(const int8_t *input, int8_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end);
+ static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end);
- static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
+ static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols);
static const char *const kName;
diff --git a/avx512_gemm.cc b/avx512_gemm.cc
index e7df5ad..a2042c6 100644
--- a/avx512_gemm.cc
+++ b/avx512_gemm.cc
@@ -35,7 +35,7 @@ inline __m512i QuantizerGrab(const float *input, const __m512 quant_mult_reg) {
// rearranging B.
//
// Convert to 16-bit signed integers.
-void AVX512_16bit::Quantize(const float *input, int16_t *output, float quant_mult, int size) {
+void AVX512_16bit::Quantize(const float *input, int16_t *output, float quant_mult, Index size) {
assert(size % 16 == 0);
assert(reinterpret_cast<uintptr_t>(input) % 64 == 0);
// Fill with the quantization multiplier.
@@ -48,7 +48,7 @@ void AVX512_16bit::Quantize(const float *input, int16_t *output, float quant_mul
}
// Convert to 8-bit signed integers.
-void AVX512_8bit::Quantize(const float *input, int8_t *output, float quant_mult, int size) {
+void AVX512_8bit::Quantize(const float *input, int8_t *output, float quant_mult, Index size) {
assert(size % 16 == 0);
assert(reinterpret_cast<uintptr_t>(input) % 64 == 0);
const __m512i neg127 = _mm512_set1_epi32(-127);
@@ -91,7 +91,7 @@ class QuantizeTile16 {
explicit QuantizeTile16(float mult) : mult_reg_(_mm512_set1_ps(mult)) {}
- inline __m512i ForReshape(const float *input, int cols) {
+ inline __m512i ForReshape(const float *input, Index cols) {
__m512i g0 = QuantizerGrabHalves(input, input + 16 * cols, mult_reg_);
__m512i g1 = QuantizerGrabHalves(input + 8 * cols, input + 24 * cols, mult_reg_);
__m512i packed = _mm512_packs_epi32(g0, g1);
@@ -109,7 +109,7 @@ class QuantizeTile8 {
explicit QuantizeTile8(float mult) : mult_reg_(_mm512_set1_ps(mult)) {}
- inline __m512i ForReshape(const float *input, int cols) {
+ inline __m512i ForReshape(const float *input, Index cols) {
// TODO: try alternative: _mm512_cvtsepi32_epi8 ?
const __m512i neg127 = _mm512_set1_epi8(-127);
// In reverse order: grabbing the first 32-bit values from each 128-bit register, then the second 32-bit values, etc.
@@ -137,30 +137,30 @@ class QuantizeTile8 {
} // namespace
-void AVX512_16bit::PrepareB(const float *input, int16_t *output, float quant_mult, int rows, int cols) {
+void AVX512_16bit::PrepareB(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) {
PrepareBFor16(input, output, QuantizeTile16(quant_mult), rows, cols);
}
-void AVX512_16bit::SelectColumnsB(const int16_t *input, int16_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end) {
+void AVX512_16bit::SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows * 2, cols_begin, cols_end);
}
-void AVX512_8bit::PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols) {
+void AVX512_8bit::PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
PrepareBFor8(input, output, QuantizeTile8(quant_mult), rows, cols);
}
-void AVX512_8bit::SelectColumnsB(const int8_t *input, int8_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end) {
+void AVX512_8bit::SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
SelectColumnsOfB((const __m512i*)input, (__m512i*)output, rows, cols_begin, cols_end);
}
-void AVX512_16bit::Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) {
+void AVX512_16bit::Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
// The unquantization is only 256-bit wide because there are 8 results.
Multiply16<__m512i, __m256> (A, B, C, unquant_mult, A_rows, width, B_cols);
}
// Special AVX512 implementation due to having 32 registers (so I don't have to
// allocate registers manually) and no sign instruction.
-void AVX512_8bit::Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) {
+void AVX512_8bit::Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
typedef __m512i Integer;
typedef __m256 Float; // For quantization we only do 8 at a time.
// This is copy-paste from Multiply8_SSE2OrAVX2.
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 56efdbb..f9b0f81 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -2,7 +2,7 @@
#include <stdint.h>
#include <cstdint>
-#include "cpu_type.h"
+#include "types.h"
/* AVX512 implementation.
* This uses AVX512BW, AVX512DQ, and might use AVX512VL
@@ -20,24 +20,24 @@ struct AVX512_16bit {
// Currently A is prepared by quantization but this could theoretically change.
// rows * cols must be a multiple of 16.
- static inline void PrepareA(const float *input, int16_t *output, float quant_mult, int rows, int cols) {
+ static inline void PrepareA(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) {
Quantize(input, output, quant_mult, rows * cols);
}
// Technically output can be unaligned in Quantize.
// But then it will need to be aligned for Multiply.
// size must be a multiple of 16.
- static void Quantize(const float *input, int16_t *output, float quant_mult, int size);
+ static void Quantize(const float *input, int16_t *output, float quant_mult, Index size);
// Tile size for B; B must be a multiple of this block size.
- static const int kBTileRow = 32;
- static const int kBTileCol = 8;
+ static const Index kBTileRow = 32;
+ static const Index kBTileCol = 8;
- static void PrepareB(const float *input, int16_t *output, float quant_mult, int rows, int cols);
+ static void PrepareB(const float *input, int16_t *output, float quant_mult, Index rows, Index cols);
- static void SelectColumnsB(const int16_t *input, int16_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end);
+ static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end);
- static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
+ static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols);
static const char *const kName;
@@ -48,23 +48,23 @@ struct AVX512_8bit {
typedef int8_t Integer;
// Currently A is prepared by quantization but this could theoretically change.
- static inline void PrepareA(const float *input, int8_t *output, float quant_mult, int rows, int cols) {
+ static inline void PrepareA(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
Quantize(input, output, quant_mult, rows * cols);
}
// Technically output can be unaligned in Quantize.
// But then it will need to be aligned for Multiply.
- static void Quantize(const float *input, int8_t *output, float quant_mult, int size);
+ static void Quantize(const float *input, int8_t *output, float quant_mult, Index size);
// Tile size for B; B must be a multiple of this block size.
- static const int kBTileRow = 64;
- static const int kBTileCol = 8;
+ static const Index kBTileRow = 64;
+ static const Index kBTileCol = 8;
- static void PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols);
+ static void PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols);
- static void SelectColumnsB(const int8_t *input, int8_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end);
+ static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end);
- static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
+ static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols);
static const char *const kName;
diff --git a/benchmark.cc b/benchmark.cc
index bcefda8..a629049 100644
--- a/benchmark.cc
+++ b/benchmark.cc
@@ -25,7 +25,7 @@ float MaxAbsoluteBaseline(const float *begin, const float *end) {
}
void BenchmarkMaxAbsolute() {
- const int size = 4096 * 4096;
+ const Index size = 4096 * 4096;
AlignedVector<float> v(size);
for (int i = 0; i < size; ++i) {
v[i] = (float)rand() / (float)RAND_MAX;
@@ -46,7 +46,7 @@ void BenchmarkMaxAbsolute() {
}
struct RandomMatrices {
- RandomMatrices(int A_rows_in, int width_in, int B_cols_in) :
+ RandomMatrices(Index A_rows_in, Index width_in, Index B_cols_in) :
A_rows(A_rows_in), width(width_in), B_cols(B_cols_in),
A(A_rows * width), B(width * B_cols) {
for (int i = 0; i < A_rows * width; i++) {
@@ -58,7 +58,7 @@ struct RandomMatrices {
}
}
- const int A_rows, width, B_cols;
+ const Index A_rows, width, B_cols;
AlignedVector<float> A, B;
};
diff --git a/example.cc b/example.cc
index 0072596..4a9f244 100644
--- a/example.cc
+++ b/example.cc
@@ -8,10 +8,11 @@
#include <math.h>
int main() {
- const int A_rows = 1;
+ using intgemm::Index;
+ const Index A_rows = 1;
// The shared dimension: A's columns and B's rows.
- const int width = 64;
- const int B_cols = 8;
+ const Index width = 64;
+ const Index B_cols = 8;
// This is a simple vector class that allocates memory aligned to 64 bytes.
// You don't have to use it; just use aligned_alloc and friends directly.
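
Since example.cc appears here only as a fragment, here is a hedged, self-contained sketch of the whole 16-bit round trip with the new Index type, using only signatures visible in this diff; quant_mult = 1024.0f and the random inputs are illustrative choices, not taken from the commit:

    #include "intgemm.h"
    #include <cstdint>
    #include <cstdlib>

    int main() {
      using intgemm::Index;
      const Index A_rows = 1, width = 64, B_cols = 8;
      const float quant_mult = 1024.0f;  // illustrative value, not from the diff
      // The asserts in the kernels above want 64-byte alignment; every size here is a multiple of 64 bytes.
      float *A = static_cast<float*>(aligned_alloc(64, A_rows * width * sizeof(float)));
      float *B = static_cast<float*>(aligned_alloc(64, width * B_cols * sizeof(float)));
      float *C = static_cast<float*>(aligned_alloc(64, 64));  // A_rows * B_cols floats, padded to 64 bytes
      int16_t *A_q = static_cast<int16_t*>(aligned_alloc(64, A_rows * width * sizeof(int16_t)));
      int16_t *B_q = static_cast<int16_t*>(aligned_alloc(64, width * B_cols * sizeof(int16_t)));
      for (Index i = 0; i < A_rows * width; ++i) A[i] = rand() / static_cast<float>(RAND_MAX);
      for (Index i = 0; i < width * B_cols; ++i) B[i] = rand() / static_cast<float>(RAND_MAX);
      intgemm::Int16::PrepareA(A, A_q, quant_mult, A_rows, width);
      intgemm::Int16::PrepareB(B, B_q, quant_mult, width, B_cols);
      // Both operands carry a factor of quant_mult, so unquantize by its square.
      intgemm::Int16::Multiply(A_q, B_q, C, 1.0f / (quant_mult * quant_mult), A_rows, width, B_cols);
      free(A); free(B); free(C); free(A_q); free(B_q);
    }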
diff --git a/interleave.h b/interleave.h
index 7bf89b1..f654dd2 100644
--- a/interleave.h
+++ b/interleave.h
@@ -196,7 +196,7 @@ template <class Register> static inline void Transpose8InLane(
// 256 272
// 257 273
// ... ...
-template <class Quantizer> static inline void PrepareBFor8(const float *input, int8_t *output_shadow, Quantizer q, int rows, int cols) {
+template <class Quantizer> static inline void PrepareBFor8(const float *input, int8_t *output_shadow, Quantizer q, Index rows, Index cols) {
typedef typename Quantizer::Integer Register;
// Currently all multipliers have a stride of 8 columns.
const int kColStride = 8;
@@ -229,7 +229,7 @@ template <class Quantizer> static inline void PrepareBFor8(const float *input, i
}
}
-template <class Quantizer> static inline void PrepareBFor16(const float *input, int16_t *output_shadow, Quantizer q, int rows, int cols) {
+template <class Quantizer> static inline void PrepareBFor16(const float *input, int16_t *output_shadow, Quantizer q, Index rows, Index cols) {
typedef typename Quantizer::Integer Register;
assert(cols % 8 == 0);
assert(rows % (sizeof(Register) / sizeof(int16_t)) == 0);
@@ -250,7 +250,7 @@ template <class Quantizer> static inline void PrepareBFor16(const float *input,
/* Select columns of B from PrepareB format to PrepareB format.
*/
-template <class Register> static inline void SelectColumnsOfB(const Register *input, Register *output, int rows_bytes /* number of bytes in a row */, const std::size_t *cols_begin, const std::size_t *cols_end) {
+template <class Register> static inline void SelectColumnsOfB(const Register *input, Register *output, Index rows_bytes /* number of bytes in a row */, const Index *cols_begin, const Index *cols_end) {
assert(rows_bytes % sizeof(Register) == 0);
assert((cols_end - cols_begin) % 8 == 0);
// Do columns for multiples of 8.
diff --git a/intgemm.cc b/intgemm.cc
index d3cc976..d286c1a 100644
--- a/intgemm.cc
+++ b/intgemm.cc
@@ -1,6 +1,6 @@
#include "intgemm.h"
-#include "cpu_type.h"
+#include "types.h"
#include "sse2_gemm.h"
#include "ssse3_gemm.h"
#include "avx2_gemm.h"
@@ -21,16 +21,16 @@ const char *UnsupportedCPU::what() const throw() {
namespace {
struct Unsupported_16bit {
- static void Quantize(const float *, int16_t *, float, int) {
+ static void Quantize(const float *, int16_t *, float, Index) {
throw UnsupportedCPU();
}
- static void PrepareB(const float *, int16_t *, float, int, int) {
+ static void PrepareB(const float *, int16_t *, float, Index, Index) {
throw UnsupportedCPU();
}
- static void SelectColumnsB(const int16_t *input, int16_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end) {
+ static void SelectColumnsB(const int16_t *, int16_t *, Index, const Index *, const Index *) {
throw UnsupportedCPU();
}
- static void Multiply(const int16_t *, const int16_t *, float *C, float, int, int, int) {
+ static void Multiply(const int16_t *, const int16_t *, float *, float, Index, Index, Index) {
throw UnsupportedCPU();
}
static const char *const kName;
@@ -38,16 +38,16 @@ struct Unsupported_16bit {
const char *const Unsupported_16bit::kName = "16-bit Unsupported";
struct Unsupported_8bit {
- static void Quantize(const float *, int8_t *, float, int) {
+ static void Quantize(const float *, int8_t *, float, Index) {
throw UnsupportedCPU();
}
- static void PrepareB(const float *, int8_t *, float, int, int) {
+ static void PrepareB(const float *, int8_t *, float, Index, Index) {
throw UnsupportedCPU();
}
- static void SelectColumnsB(const int8_t *input, int8_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end) {
+ static void SelectColumnsB(const int8_t *, int8_t *, Index, const Index *, const Index *) {
throw UnsupportedCPU();
}
- static void Multiply(const int8_t *, const int8_t *, float *C, float, int, int, int) {
+ static void Multiply(const int8_t *, const int8_t *, float *, float, Index, Index, Index) {
throw UnsupportedCPU();
}
static const char *const kName;
@@ -99,16 +99,16 @@ float AVX512_MaxAbsolute(const float *begin, const float *end) {
} // namespace
-void (*Int16::Quantize)(const float *input, int16_t *output, float quant_mult, int size) = ChooseCPU(AVX512_16bit::Quantize, AVX2_16bit::Quantize, SSE2_16bit::Quantize, SSE2_16bit::Quantize, Unsupported_16bit::Quantize);
-void (*Int16::PrepareB)(const float *input, int16_t *output, float quant_mult, int rows, int cols) = ChooseCPU(AVX512_16bit::PrepareB, AVX2_16bit::PrepareB, SSE2_16bit::PrepareB, SSE2_16bit::PrepareB, Unsupported_16bit::PrepareB);
-void (*Int16::SelectColumnsB)(const int16_t *input, int16_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end) = ChooseCPU(AVX512_16bit::SelectColumnsB, AVX2_16bit::SelectColumnsB, SSE2_16bit::SelectColumnsB, SSE2_16bit::SelectColumnsB, Unsupported_16bit::SelectColumnsB);
-void (*Int16::Multiply)(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) = ChooseCPU(AVX512_16bit::Multiply, AVX2_16bit::Multiply, SSE2_16bit::Multiply, SSE2_16bit::Multiply, Unsupported_16bit::Multiply);
+void (*Int16::Quantize)(const float *input, int16_t *output, float quant_mult, Index size) = ChooseCPU(AVX512_16bit::Quantize, AVX2_16bit::Quantize, SSE2_16bit::Quantize, SSE2_16bit::Quantize, Unsupported_16bit::Quantize);
+void (*Int16::PrepareB)(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(AVX512_16bit::PrepareB, AVX2_16bit::PrepareB, SSE2_16bit::PrepareB, SSE2_16bit::PrepareB, Unsupported_16bit::PrepareB);
+void (*Int16::SelectColumnsB)(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(AVX512_16bit::SelectColumnsB, AVX2_16bit::SelectColumnsB, SSE2_16bit::SelectColumnsB, SSE2_16bit::SelectColumnsB, Unsupported_16bit::SelectColumnsB);
+void (*Int16::Multiply)(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_16bit::Multiply, AVX2_16bit::Multiply, SSE2_16bit::Multiply, SSE2_16bit::Multiply, Unsupported_16bit::Multiply);
const char *const Int16::kName = ChooseCPU(AVX512_16bit::kName, AVX2_16bit::kName, SSE2_16bit::kName, SSE2_16bit::kName, Unsupported_16bit::kName);
-void (*Int8::Quantize)(const float *input, int8_t *output, float quant_mult, int size) = ChooseCPU(AVX512_8bit::Quantize, AVX2_8bit::Quantize, SSSE3_8bit::Quantize, Unsupported_8bit::Quantize, Unsupported_8bit::Quantize);
-void (*Int8::PrepareB)(const float *input, int8_t *output, float quant_mult, int rows, int cols) = ChooseCPU(AVX512_8bit::PrepareB, AVX2_8bit::PrepareB, SSSE3_8bit::PrepareB, Unsupported_8bit::PrepareB, Unsupported_8bit::PrepareB);
-void (*Int8::SelectColumnsB)(const int8_t *input, int8_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end) = ChooseCPU(AVX512_8bit::SelectColumnsB, AVX2_8bit::SelectColumnsB, SSSE3_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB);
-void (*Int8::Multiply)(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) = ChooseCPU(AVX512_8bit::Multiply, AVX2_8bit::Multiply, SSSE3_8bit::Multiply, Unsupported_8bit::Multiply, Unsupported_8bit::Multiply);
+void (*Int8::Quantize)(const float *input, int8_t *output, float quant_mult, Index size) = ChooseCPU(AVX512_8bit::Quantize, AVX2_8bit::Quantize, SSSE3_8bit::Quantize, Unsupported_8bit::Quantize, Unsupported_8bit::Quantize);
+void (*Int8::PrepareB)(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(AVX512_8bit::PrepareB, AVX2_8bit::PrepareB, SSSE3_8bit::PrepareB, Unsupported_8bit::PrepareB, Unsupported_8bit::PrepareB);
+void (*Int8::SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(AVX512_8bit::SelectColumnsB, AVX2_8bit::SelectColumnsB, SSSE3_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB);
+void (*Int8::Multiply)(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) = ChooseCPU(AVX512_8bit::Multiply, AVX2_8bit::Multiply, SSSE3_8bit::Multiply, Unsupported_8bit::Multiply, Unsupported_8bit::Multiply);
const char *const Int8::kName = ChooseCPU(AVX512_8bit::kName, AVX2_8bit::kName, SSSE3_8bit::kName, Unsupported_8bit::kName, Unsupported_8bit::kName);
const CPUType kCPU = ChooseCPU(CPU_AVX512BW, CPU_AVX2, CPU_SSSE3, CPU_SSE2, CPU_UNSUPPORTED);
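
ChooseCPU itself is not part of this diff; per the comment in intgemm.h it dispatches on runtime CPUID, taking one candidate per CPU tier and returning the leftmost one the machine supports. A hedged sketch of that selection logic, assuming the GCC/Clang __builtin_cpu_supports builtin rather than whatever detection the real helper uses:

    template <class T>
    T ChooseCPU(T avx512bw, T avx2, T ssse3, T sse2, T unsupported) {
      // Prefer the widest instruction set the running CPU can execute.
      if (__builtin_cpu_supports("avx512bw")) return avx512bw;
      if (__builtin_cpu_supports("avx2"))     return avx2;
      if (__builtin_cpu_supports("ssse3"))    return ssse3;
      if (__builtin_cpu_supports("sse2"))     return sse2;
      return unsupported;
    }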
diff --git a/intgemm.h b/intgemm.h
index 07ff62a..58fa8cc 100644
--- a/intgemm.h
+++ b/intgemm.h
@@ -45,6 +45,8 @@
#include <stdint.h>
#include <exception>
+#include "types.h"
+
/* Dispatch to functions based on runtime CPUID. This adds one call-by-variable to each call. */
namespace intgemm {
@@ -64,32 +66,32 @@ struct Int16 {
typedef int16_t Integer;
// A's size must be a multiple of 1x32.
- static const int kATileRow = 1;
- static const int kATileCol = 32;
+ static const Index kATileRow = 1;
+ static const Index kATileCol = 32;
// B's size must be a multiple of 32x8.
- static const int kBTileRow = 32;
- static const int kBTileCol = 8;
+ static const Index kBTileRow = 32;
+ static const Index kBTileCol = 8;
// Currently A is prepared by quantization but this could theoretically change.
// A's columns must be a multiple of 8.
// The number of rows is anything.
- static inline void PrepareA(const float *input, int16_t *output, float quant_mult, int rows, int cols) {
+ static inline void PrepareA(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) {
Quantize(input, output, quant_mult, rows * cols);
}
// Multiply floats by quant_mult then convert to 16-bit integers with saturation.
// input
- static void (*Quantize)(const float *input, int16_t *output, float quant_mult, int size);
+ static void (*Quantize)(const float *input, int16_t *output, float quant_mult, Index size);
// Warning: the output of PrepareB depends on the CPU.
// It will match the Multiply function on the same CPU though.
- static void (*PrepareB)(const float *input, int16_t *output, float quant_mult, int rows, int cols);
+ static void (*PrepareB)(const float *input, int16_t *output, float quant_mult, Index rows, Index cols);
// Select columns from a prepared B matrix. The number of selected columns must be a multiple of 8.
- static void (*SelectColumnsB)(const int16_t *input, int16_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end);
+ static void (*SelectColumnsB)(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end);
// Multiply C = A * B, presuming A and B have been prepared.
- static void (*Multiply)(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
+ static void (*Multiply)(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols);
static const char *const kName;
};
@@ -99,31 +101,31 @@ struct Int8 {
typedef int8_t Integer;
// A's size must be a multiple of 1x64.
- static const int kATileRow = 1;
- static const int kATileCol = 64;
+ static const Index kATileRow = 1;
+ static const Index kATileCol = 64;
// B's size must be a multiple of 64x8.
- static const int kBTileRow = 64;
- static const int kBTileCol = 8;
+ static const Index kBTileRow = 64;
+ static const Index kBTileCol = 8;
// Currently A is prepared by quantization but this could theoretically change.
// A's columns must be a multiple of 8.
// The number of rows is anything.
- static inline void PrepareA(const float *input, int8_t *output, float quant_mult, int rows, int cols) {
+ static inline void PrepareA(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
Quantize(input, output, quant_mult, rows * cols);
}
// Multiply floats by quant_mult then convert to 8-bit integers with saturation.
- static void (*Quantize)(const float *input, int8_t *output, float quant_mult, int size);
+ static void (*Quantize)(const float *input, int8_t *output, float quant_mult, Index size);
// Warning: the output of PrepareB depends on the CPU.
// It will match the Multiply function on the same CPU though.
- static void (*PrepareB)(const float *input, int8_t *output, float quant_mult, int rows, int cols);
+ static void (*PrepareB)(const float *input, int8_t *output, float quant_mult, Index rows, Index cols);
// Select columns from a prepared B matrix. The number of selected columns must be a multiple of 8.
- static void (*SelectColumnsB)(const int8_t *input, int8_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end);
+ static void (*SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end);
// Multiply C = A * B, presuming A and B have been prepared.
- static void (*Multiply)(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
+ static void (*Multiply)(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols);
static const char *const kName;
};
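
The Quantize contract stated above ("multiply floats by quant_mult then convert with saturation") has a simple scalar meaning. A hedged sketch for the 16-bit case; the exact rounding is whatever the SIMD float-to-int conversion applies, assumed round-to-nearest here:

    #include "intgemm.h"
    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    void QuantizeRef(const float *input, int16_t *output, float quant_mult, intgemm::Index size) {
      for (intgemm::Index i = 0; i < size; ++i) {
        // Scale, round (assumed nearest), then saturate to the int16_t range.
        float q = std::round(input[i] * quant_mult);
        output[i] = static_cast<int16_t>(std::max(-32768.0f, std::min(32767.0f, q)));
      }
    }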
diff --git a/multiply.h b/multiply.h
index de1d677..931cc4b 100644
--- a/multiply.h
+++ b/multiply.h
@@ -204,7 +204,7 @@ template <class Register> inline Register Pack0123(Register sum0, Register sum1,
// A_rows can be anything non-negative.
// width must be a multiple of the register size.
// B_cols must be a multiple of 8.
-template <class Integer, class Float> inline void Multiply16(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) {
+template <class Integer, class Float> inline void Multiply16(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
assert(width % (sizeof(Integer) / sizeof(int16_t)) == 0);
assert(B_cols % 8 == 0);
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0);
@@ -409,7 +409,7 @@ struct Multiply8_C {
}
};
-template <class Algo, class Integer, class Float> inline void Multiply8_SSE2OrAVX2(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) {
+template <class Algo, class Integer, class Float> inline void Multiply8_SSE2OrAVX2(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
assert(width % sizeof(Integer) == 0);
assert(B_cols % 8 == 0);
assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0);
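
All of the Multiply kernels compute the same quantity that multiply_test.cc's SlowRefInt (below) checks against: C = unquant_mult * (A_q x B_q), accumulated in 32-bit integers. As a scalar reference for the signatures above — B is shown row-major here, whereas the real kernels read B in its PrepareB layout:

    #include "intgemm.h"
    #include <cstdint>

    template <class Integer>
    void ReferenceMultiply(const Integer *A, const Integer *B, float *C, float unquant_mult,
                           intgemm::Index A_rows, intgemm::Index width, intgemm::Index B_cols) {
      for (intgemm::Index r = 0; r < A_rows; ++r)
        for (intgemm::Index c = 0; c < B_cols; ++c) {
          int32_t sum = 0;
          for (intgemm::Index k = 0; k < width; ++k)
            sum += static_cast<int32_t>(A[r * width + k]) * static_cast<int32_t>(B[k * B_cols + c]);
          C[r * B_cols + c] = unquant_mult * static_cast<float>(sum);
        }
    }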
diff --git a/multiply_test.cc b/multiply_test.cc
index 1700399..cb2d15a 100644
--- a/multiply_test.cc
+++ b/multiply_test.cc
@@ -21,7 +21,7 @@
namespace intgemm {
// Rearrange a tile of simd x unroll entries.
-template <class V> void SlowRearrangeTile(const V *from, V *to, int simd, int unroll, int cols) {
+template <class V> void SlowRearrangeTile(const V *from, V *to, int simd, int unroll, Index cols) {
for (int i = 0; i < unroll; ++i) {
for (int j = 0; j < simd; ++j) {
to[simd * i + j] = from[cols * j + i];
@@ -29,7 +29,7 @@ template <class V> void SlowRearrangeTile(const V *from, V *to, int simd, int un
}
}
-template <class V> void SlowRearrange(const V *from, V *to, int simd, int unroll, int rows, int cols) {
+template <class V> void SlowRearrange(const V *from, V *to, int simd, int unroll, Index rows, Index cols) {
for (int c = 0; c < cols; c += unroll) {
for (int r = 0; r < rows; r += simd) {
SlowRearrangeTile(from + cols * r + c, to, simd, unroll, cols);
@@ -38,7 +38,7 @@ template <class V> void SlowRearrange(const V *from, V *to, int simd, int unroll
}
}
-template <class V> void SlowTranspose(const V *from, V *to, int rows, int cols) {
+template <class V> void SlowTranspose(const V *from, V *to, Index rows, Index cols) {
for (int r = 0; r < rows; ++r) {
for (int c = 0; c < cols; ++c) {
to[rows * c + r] = from[cols * r + c];
@@ -84,7 +84,7 @@ void TestTranspose8() {
}
}
-template <class T> void PrintMatrix(const T *mem, int rows, int cols) {
+template <class T> void PrintMatrix(const T *mem, Index rows, Index cols) {
for (int r = 0; r < rows; ++r) {
for (int c = 0; c < cols; ++c) {
std::cout << std::setw(4) << (int64_t) mem[r * cols + c] << ' ';
@@ -93,7 +93,7 @@ template <class T> void PrintMatrix(const T *mem, int rows, int cols) {
}
}
-template <class Routine> void TestPrepare(int rows = 32, int cols = 16) {
+template <class Routine> void TestPrepare(Index rows = 32, Index cols = 16) {
if (intgemm::kCPU < Routine::kUses) return;
// Create array.
AlignedVector<float> input(rows * cols);
@@ -125,7 +125,7 @@ template <class Routine> void TestPrepare(int rows = 32, int cols = 16) {
}
}
-template <class Routine> void TestSelectColumnsB(int rows = 32, int cols = 16) {
+template <class Routine> void TestSelectColumnsB(Index rows = 32, Index cols = 16) {
if (intgemm::kCPU < Routine::kUses) return;
AlignedVector<float> input(rows * cols);
for (int i = 0; i < rows * cols; ++i) {
@@ -136,7 +136,7 @@ template <class Routine> void TestSelectColumnsB(int rows = 32, int cols = 16) {
Routine::PrepareB(input.get(), prepared.get(), 1, rows, cols);
int kSelectCols = 24;
- std::size_t select_cols[kSelectCols];
+ Index select_cols[kSelectCols];
for (int i = 0; i < kSelectCols; ++i) {
select_cols[i] = rand() % cols;
}
@@ -218,7 +218,7 @@ template <float (*Backend) (const float *, const float *)> void TestMaxAbsolute(
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
// Compute A*B slowly in floats.
-void SlowRefFloat(const float *A, const float *B, float *C, int A_rows, int width, int B_cols) {
+void SlowRefFloat(const float *A, const float *B, float *C, Index A_rows, Index width, Index B_cols) {
for (int r = 0; r < A_rows; ++r) {
for (int c = 0; c < B_cols; ++c) {
float sum = 0.0f;
@@ -231,7 +231,7 @@ void SlowRefFloat(const float *A, const float *B, float *C, int A_rows, int widt
}
// Compute A*B slowly from integers.
-template <class Integer> void SlowRefInt(const Integer *A, const Integer *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) {
+template <class Integer> void SlowRefInt(const Integer *A, const Integer *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
for (int r = 0; r < A_rows; ++r) {
for (int c = 0; c < B_cols; ++c) {
int32_t sum = 0;
@@ -258,7 +258,7 @@ void Compare(const float *float_ref, const float *int_ref, const float *int_test
std::cout << "Float MSE = " << sqrt(float_sum / size) << "\tInt MSE = " << sqrt(int_sum / size) << std::endl;
}
-template <class Routine> void TestMultiply(int A_rows, int width, int B_cols) {
+template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_cols) {
typedef typename Routine::Integer Integer;
if (intgemm::kCPU < Routine::kUses) return;
std::cout << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
@@ -295,7 +295,7 @@ template <class Routine> void TestMultiply(int A_rows, int width, int B_cols) {
Compare(float_C.get(), slowint_C.get(), test_C.get(), A_rows * B_cols);
}
-void TestBoth(int A_rows, int width, int B_cols) {
+void TestBoth(Index A_rows, Index width, Index B_cols) {
#ifndef INTGEMM_NO_AVX512
TestMultiply<AVX512_16bit>(A_rows, width, B_cols);
#endif
diff --git a/sse2_gemm.cc b/sse2_gemm.cc
index 9493775..4299d39 100644
--- a/sse2_gemm.cc
+++ b/sse2_gemm.cc
@@ -44,7 +44,7 @@ class QuantizeTile16 {
* This code: 0.00228409, 0.00204906
* With _mm_cvtps_pi16 basis: 0.00391884, 0.00390869
*/
-void SSE2_16bit::Quantize(const float *input, int16_t *output, float quant_mult, int size) {
+void SSE2_16bit::Quantize(const float *input, int16_t *output, float quant_mult, Index size) {
assert(size % 8 == 0);
assert(reinterpret_cast<uintptr_t>(input) % 16 == 0);
assert(reinterpret_cast<uintptr_t>(output) % 16 == 0);
@@ -55,15 +55,15 @@ void SSE2_16bit::Quantize(const float *input, int16_t *output, float quant_mult,
}
}
-void SSE2_16bit::PrepareB(const float *input, int16_t *output, float quant_mult, int rows, int cols) {
+void SSE2_16bit::PrepareB(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) {
PrepareBFor16(input, output, QuantizeTile16(quant_mult), rows, cols);
}
-void SSE2_16bit::SelectColumnsB(const int16_t *input, int16_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end) {
+void SSE2_16bit::SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows * 2, cols_begin, cols_end);
}
-void SSE2_16bit::Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) {
+void SSE2_16bit::Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
Multiply16<__m128i, __m128>(A, B, C, unquant_mult, A_rows, width, B_cols);
}
diff --git a/sse2_gemm.h b/sse2_gemm.h
index da75915..0f26362 100644
--- a/sse2_gemm.h
+++ b/sse2_gemm.h
@@ -1,5 +1,5 @@
#pragma once
-#include "cpu_type.h"
+#include "types.h"
#include <cstdint>
#include <stdint.h>
// 8 bit is in ssse3_gemm.h
@@ -11,21 +11,21 @@ struct SSE2_16bit {
typedef int16_t Integer;
// Currently A is prepared by quantization but this could theoretically change.
- static inline void PrepareA(const float *input, int16_t *output, float quant_mult, int rows, int cols) {
+ static inline void PrepareA(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) {
Quantize(input, output, quant_mult, rows * cols);
}
- static void Quantize(const float *input, int16_t *output, float quant_mult, int size);
+ static void Quantize(const float *input, int16_t *output, float quant_mult, Index size);
// Tile size for B; B must be a multiple of this block size.
- static const int kBTileRow = 8;
- static const int kBTileCol = 8;
+ static const Index kBTileRow = 8;
+ static const Index kBTileCol = 8;
- static void PrepareB(const float *input, int16_t *output, float quant_mult, int rows, int cols);
+ static void PrepareB(const float *input, int16_t *output, float quant_mult, Index rows, Index cols);
- static void SelectColumnsB(const int16_t *input, int16_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end);
+ static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end);
- static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
+ static void Multiply(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols);
static const char *const kName;
diff --git a/ssse3_gemm.cc b/ssse3_gemm.cc
index 17c0e8f..d5de13d 100644
--- a/ssse3_gemm.cc
+++ b/ssse3_gemm.cc
@@ -22,7 +22,7 @@ class QuantizeTile8 {
explicit QuantizeTile8(float mult) : mult_reg_(_mm_set1_ps(mult)) {}
- inline __m128i ForReshape(const float *input, int cols) {
+ inline __m128i ForReshape(const float *input, Index cols) {
// Skip a row.
return Tile(input, input + 2 * cols);
}
@@ -61,7 +61,7 @@ class QuantizeTile8 {
} // namespace
-void SSSE3_8bit::Quantize(const float *input, int8_t *output, float quant_mult, int size) {
+void SSSE3_8bit::Quantize(const float *input, int8_t *output, float quant_mult, Index size) {
assert(size % 16 == 0);
assert(reinterpret_cast<uintptr_t>(input) % 16 == 0);
assert(reinterpret_cast<uintptr_t>(output) % 16 == 0);
@@ -72,15 +72,15 @@ void SSSE3_8bit::Quantize(const float *input, int8_t *output, float quant_mult,
}
}
-void SSSE3_8bit::PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols) {
+void SSSE3_8bit::PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
PrepareBFor8(input, output, QuantizeTile8(quant_mult), rows, cols);
}
-void SSSE3_8bit::SelectColumnsB(const int8_t *input, int8_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end) {
+void SSSE3_8bit::SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows, cols_begin, cols_end);
}
-void SSSE3_8bit::Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols) {
+void SSSE3_8bit::Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols) {
Multiply8_SSE2OrAVX2<Multiply8_C, __m128i, __m128>(A, B, C, unquant_mult, A_rows, width, B_cols);
}
diff --git a/ssse3_gemm.h b/ssse3_gemm.h
index 07bfb5e..4993ef6 100644
--- a/ssse3_gemm.h
+++ b/ssse3_gemm.h
@@ -1,5 +1,5 @@
#pragma once
-#include "cpu_type.h"
+#include "types.h"
#include <cstdint>
#include <stdint.h>
@@ -12,21 +12,21 @@ struct SSSE3_8bit {
typedef int8_t Integer;
// Currently A is prepared by quantization but this could theoretically change.
- static inline void PrepareA(const float *input, int8_t *output, float quant_mult, int rows, int cols) {
+ static inline void PrepareA(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) {
Quantize(input, output, quant_mult, rows * cols);
}
- static void Quantize(const float *input, int8_t *output, float quant_mult, int size);
+ static void Quantize(const float *input, int8_t *output, float quant_mult, Index size);
// Tile size for B; B must be a multiple of this block size.
- static const int kBTileRow = 16;
- static const int kBTileCol = 8;
+ static const Index kBTileRow = 16;
+ static const Index kBTileCol = 8;
- static void PrepareB(const float *input, int8_t *output, float quant_mult, int rows, int cols);
+ static void PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols);
- static void SelectColumnsB(const int8_t *input, int8_t *output, int rows, const std::size_t *cols_begin, const std::size_t *cols_end);
+ static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end);
- static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, int A_rows, int width, int B_cols);
+ static void Multiply(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols);
static const char *const kName;
diff --git a/cpu_type.h b/types.h
index 58c6584..0954dd3 100644
--- a/cpu_type.h
+++ b/types.h
@@ -2,6 +2,8 @@
namespace intgemm {
+typedef unsigned int Index;
+
// If you want to detect the CPU and dispatch yourself, here's what to use:
typedef enum {CPU_AVX512BW = 4, CPU_AVX2 = 3, CPU_SSSE3 = 2, CPU_SSE2 = 1, CPU_UNSUPPORTED} CPUType;
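
The comment above is the hook for doing that dispatch yourself: intgemm.cc exports kCPU (set once through ChooseCPU), and multiply_test.cc gates each routine with intgemm::kCPU < Routine::kUses. A hedged usage sketch:

    #include "intgemm.h"  // also brings in types.h

    bool CanUseAVX2() {
      // Mirrors the `if (intgemm::kCPU < Routine::kUses) return;` guards in the tests:
      // the CPUType enum values are ordered, so >= means "at least this instruction set".
      return intgemm::kCPU >= intgemm::CPU_AVX2;
    }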