Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2020-03-03 18:11:46 +0300
committerKenneth Heafield <github@kheafield.com>2020-03-03 18:11:46 +0300
commit0654d2785d70d31c65b072814acf7bd9a4cc5164 (patch)
tree6560374a65f3f2109b108b00b75aa59a51b52c64
parent8df31c7fa36bc0deabae0f49dd6a345c2e2ab703 (diff)
Compiler warnings
-rw-r--r--CMakeLists.txt6
-rw-r--r--avx2_gemm.h2
-rw-r--r--avx512_gemm.h2
-rw-r--r--benchmarks/benchmark.cc2
-rw-r--r--callbacks/implementations.inl2
-rw-r--r--intgemm.cc2
-rw-r--r--kernels/implementations.inl22
-rw-r--r--ssse3_gemm.h2
-rw-r--r--test/kernels/add_bias_test.cc2
-rw-r--r--test/kernels/bitwise_not_test.cc2
-rw-r--r--test/kernels/downcast_test.cc6
-rw-r--r--test/kernels/exp_test.cc2
-rw-r--r--test/kernels/floor_test.cc2
-rw-r--r--test/kernels/multiply_sat_test.cc4
-rw-r--r--test/kernels/multiply_test.cc2
-rw-r--r--test/kernels/quantize_test.cc2
-rw-r--r--test/kernels/relu_test.cc2
-rw-r--r--test/kernels/rescale_test.cc2
-rw-r--r--test/kernels/sigmoid_test.cc2
-rw-r--r--test/kernels/tanh_test.cc2
-rw-r--r--test/kernels/unquantize_test.cc2
-rw-r--r--test/kernels/upcast_test.cc6
-rw-r--r--test/kernels/write_test.cc2
-rw-r--r--test/multiply_test.cc14
-rw-r--r--test/test.h2
25 files changed, 55 insertions, 41 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 022fa7f..f29ccff 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -10,6 +10,12 @@ endif()
set(CMAKE_CXX_STANDARD 11)
+if(MSVC)
+ add_compile_options(/W4 /WX)
+else()
+ add_compile_options(-Wall -Wextra -pedantic -Werror -Wno-unknown-pragmas)
+endif()
+
# Check if compiler supports AVX512
try_compile(INTGEMM_COMPILER_SUPPORTS_AVX512
${CMAKE_CURRENT_BINARY_DIR}/compile_tests
diff --git a/avx2_gemm.h b/avx2_gemm.h
index 54f3c6c..335fa0d 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -118,7 +118,7 @@ class QuantizeTile8 {
INTGEMM_AVX2 Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const {
const float* inputs[4];
- for (int i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) {
+ for (Index i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) {
while (cols_left < sizeof(Register) / sizeof(float)) {
input += cols * (row_step - 1);
cols_left += cols;
diff --git a/avx512_gemm.h b/avx512_gemm.h
index eba0322..b3499af 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -108,7 +108,7 @@ class QuantizeTile8 {
static const __m512i shuffle_param = _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
const float* inputs[4];
- for (int i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) {
+ for (Index i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) {
while (cols_left < sizeof(Register) / sizeof(float)) {
input += cols * (row_step - 1);
cols_left += cols;
diff --git a/benchmarks/benchmark.cc b/benchmarks/benchmark.cc
index 9f4fad3..26b0ac5 100644
--- a/benchmarks/benchmark.cc
+++ b/benchmarks/benchmark.cc
@@ -137,7 +137,7 @@ template <class Backend> void Print(std::vector<std::vector<uint64_t>> &stats, i
} // namespace
// Program takes no input
-int main(int argc, char ** argv) {
+int main(int, char ** argv) {
std::cerr << "Remember to run this on a specific core:\ntaskset --cpu-list 0 " << argv[0] << std::endl;
using namespace intgemm;
diff --git a/callbacks/implementations.inl b/callbacks/implementations.inl
index dce89b2..25f8aa3 100644
--- a/callbacks/implementations.inl
+++ b/callbacks/implementations.inl
@@ -110,7 +110,7 @@ public:
unquant_mult = set1_ps<vf>(config.unquant_mult);
}
- CPU_ATTR vf operator()(vi input, const OutputBufferInfo& info) {
+ CPU_ATTR vf operator()(vi input, const OutputBufferInfo&) {
return kernels::unquantize(input, unquant_mult);
}
diff --git a/intgemm.cc b/intgemm.cc
index 43b8ca6..c069424 100644
--- a/intgemm.cc
+++ b/intgemm.cc
@@ -2,7 +2,7 @@
namespace intgemm {
-float Unsupported_MaxAbsolute(const float *begin, const float *end) {
+float Unsupported_MaxAbsolute(const float * /*begin*/, const float * /*end*/) {
throw UnsupportedCPU();
}
diff --git a/kernels/implementations.inl b/kernels/implementations.inl
index 80347fc..0d3e5d9 100644
--- a/kernels/implementations.inl
+++ b/kernels/implementations.inl
@@ -231,7 +231,7 @@ CPU_ATTR static inline dvector_t<CPUType::CPU_NAME, int16_t> upcast8to16(vi inpu
input = _mm256_permute4x64_epi64(input, 0xd8 /* = 0 2 1 3 */);
auto higher_byte = _mm256_cmpgt_epi8(vzero, input);
#else
- static const auto vmax_negative = set1_epi8<vi>(0xff);
+ static const auto vmax_negative = set1_epi8<vi>(-1 /* 0xff */);
static const auto permutation_indices = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0);
input = _mm512_castpd_si512(_mm512_permutexvar_pd(permutation_indices, _mm512_castsi512_pd(input)));
@@ -254,7 +254,7 @@ CPU_ATTR static inline dvector_t<CPUType::CPU_NAME, int> upcast16to32(vi input)
input = _mm256_permute4x64_epi64(input, 0xd8 /* = 0 2 1 3 */);
auto higher_byte = _mm256_cmpgt_epi16(vzero, input);
#else
- static const auto vmax_negative = set1_epi16<vi>(0xffff);
+ static const auto vmax_negative = set1_epi16<vi>(-1 /* 0xffff */);
static const auto permutation_indices = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0);
input = _mm512_castpd_si512(_mm512_permutexvar_pd(permutation_indices, _mm512_castsi512_pd(input)));
@@ -354,10 +354,12 @@ CPU_ATTR static inline vf floor(vf input) {
/*
* Calculate approximation of e^x using Taylor series and lookup table
*/
-CPU_ATTR static inline vf exp_approx_taylor(vf x) {
#if defined(KERNELS_THIS_IS_SSE2)
+CPU_ATTR static inline vf exp_approx_taylor(vf) {
std::abort();
+}
#else
+CPU_ATTR static inline vf exp_approx_taylor(vf x) {
static constexpr int EXP_MIN = -20;
static constexpr int EXP_MAX = 20;
static constexpr float EXP_LOOKUP[EXP_MAX - EXP_MIN + 1] = {
@@ -408,13 +410,17 @@ CPU_ATTR static inline vf exp_approx_taylor(vf x) {
auto ea = i32gather_ps<4>(EXP_LOOKUP + EXP_MAX, cvtps_epi32(a));
return mul_ps(ea, result);
-#endif
}
+#endif
/*
* Sigmoid
*/
-CPU_ATTR static inline vf sigmoid(vf input) {
+CPU_ATTR static inline vf sigmoid(vf
+#ifndef KERNELS_THIS_IS_SSE2
+ input
+#endif
+ ) {
#if defined(KERNELS_THIS_IS_SSE2)
std::abort(); // TODO: missing exp_approx_taylor for SSE2
#elif defined(KERNELS_THIS_IS_AVX2)
@@ -451,18 +457,20 @@ CPU_ATTR static inline vf sigmoid(vf input) {
/*
* Tanh
*/
-CPU_ATTR static inline vf tanh(vf input) {
#if defined(KERNELS_THIS_IS_SSE2)
+CPU_ATTR static inline vf tanh(vf) {
std::abort(); // TODO: missing exp_approx_taylor for SSE2
+}
#else
+CPU_ATTR static inline vf tanh(vf input) {
const static auto vconst_zero = setzero_ps<vf>();
auto e_x = exp_approx_taylor(input);
auto e_minus_x = exp_approx_taylor(sub_ps(vconst_zero, input));
return div_ps(sub_ps(e_x, e_minus_x), add_ps(e_x, e_minus_x));
-#endif
}
+#endif
}
}
diff --git a/ssse3_gemm.h b/ssse3_gemm.h
index dc1ae83..2cf341e 100644
--- a/ssse3_gemm.h
+++ b/ssse3_gemm.h
@@ -42,7 +42,7 @@ class QuantizeTile8 {
INTGEMM_SSSE3 Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const {
const float* inputs[4];
- for (int i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) {
+ for (Index i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) {
while (cols_left < sizeof(Register) / sizeof(float)) {
input += cols * (row_step - 1);
cols_left += cols;
diff --git a/test/kernels/add_bias_test.cc b/test/kernels/add_bias_test.cc
index 4a2060e..7b10b56 100644
--- a/test/kernels/add_bias_test.cc
+++ b/test/kernels/add_bias_test.cc
@@ -22,7 +22,7 @@ void kernel_add_bias_test() {
std::fill(bias.begin(), bias.end(), 100);
*output.template as<vec_t>() = kernels::add_bias(*input.template as<vec_t>(), bias.begin(), 0);
- for (auto i = 0; i < output.size(); ++i)
+ for (std::size_t i = 0; i < output.size(); ++i)
CHECK(output[i] == ElemType_(100 + i));
}
diff --git a/test/kernels/bitwise_not_test.cc b/test/kernels/bitwise_not_test.cc
index 889e1bb..02b700b 100644
--- a/test/kernels/bitwise_not_test.cc
+++ b/test/kernels/bitwise_not_test.cc
@@ -20,7 +20,7 @@ void kernel_bitwise_not_test() {
std::iota(input.begin(), input.end(), 0);
*output.template as<vec_t>() = kernels::bitwise_not(*input.template as<vec_t>());
- for (auto i = 0; i < output.size(); ++i)
+ for (std::size_t i = 0; i < output.size(); ++i)
CHECK(output[i] == ~input[i]);
}
diff --git a/test/kernels/downcast_test.cc b/test/kernels/downcast_test.cc
index b25889f..5ecc084 100644
--- a/test/kernels/downcast_test.cc
+++ b/test/kernels/downcast_test.cc
@@ -22,7 +22,7 @@ void kernel_downcast32to8_test() {
*output.template as<vi>() = kernels::downcast32to8(
input.template as<vi>()[0], input.template as<vi>()[1],
input.template as<vi>()[2], input.template as<vi>()[3]);
- for (int i = 0; i < output.size(); ++i)
+ for (std::size_t i = 0; i < output.size(); ++i)
CHECK(output[i] == int8_t(input[i]));
}
@@ -52,7 +52,7 @@ void kernel_downcast32to16_test() {
*output.template as<vi>() = kernels::downcast32to16(
input.template as<vi>()[0], input.template as<vi>()[1]);
- for (int i = 0; i < output.size(); ++i)
+ for (std::size_t i = 0; i < output.size(); ++i)
CHECK(output[i] == int16_t(input[i]));
}
@@ -82,7 +82,7 @@ void kernel_downcast16to8_test() {
*output.template as<vi>() = kernels::downcast16to8(
input.template as<vi>()[0], input.template as<vi>()[1]);
- for (int i = 0; i < output.size(); ++i)
+ for (std::size_t i = 0; i < output.size(); ++i)
CHECK(output[i] == int8_t(input[i]));
}
diff --git a/test/kernels/exp_test.cc b/test/kernels/exp_test.cc
index d4e100e..d54f2ca 100644
--- a/test/kernels/exp_test.cc
+++ b/test/kernels/exp_test.cc
@@ -20,7 +20,7 @@ void kernel_exp_approx_taylor_test() {
std::iota(input.begin(), input.end(), -int(VECTOR_LENGTH / 2));
*output.template as<vec_t>() = kernels::exp_approx_taylor(*input.template as<vec_t>());
- for (auto i = 0; i < output.size(); ++i)
+ for (std::size_t i = 0; i < output.size(); ++i)
CHECK_EPS(output[i], exp(input[i]), 0.001f);
}
diff --git a/test/kernels/floor_test.cc b/test/kernels/floor_test.cc
index 3f4fdf3..10914a3 100644
--- a/test/kernels/floor_test.cc
+++ b/test/kernels/floor_test.cc
@@ -20,7 +20,7 @@ void kernel_floor_test() {
std::iota(input.begin(), input.end(), -int(VECTOR_LENGTH / 2));
*output.template as<vec_t>() = kernels::floor(*input.template as<vec_t>());
- for (auto i = 0; i < output.size(); ++i)
+ for (std::size_t i = 0; i < output.size(); ++i)
CHECK(output[i] == std::floor(input[i]));
}
diff --git a/test/kernels/multiply_sat_test.cc b/test/kernels/multiply_sat_test.cc
index 83ce5ac..87fc09a 100644
--- a/test/kernels/multiply_sat_test.cc
+++ b/test/kernels/multiply_sat_test.cc
@@ -21,9 +21,9 @@ void kernel_multiply_sat_test() {
std::iota(input1.begin(), input1.end(), -int(VECTOR_LENGTH / 2));
std::iota(input2.begin(), input2.end(), -int(VECTOR_LENGTH / 3));
- for (auto shift = 0; shift <= 2 * 8 * sizeof(Type_); ++shift) {
+ for (std::size_t shift = 0; shift <= 2 * 8 * sizeof(Type_); ++shift) {
*output.template as<vec_t>() = kernels::multiply_sat<Type_>(*input1.template as<vec_t>(), *input2.template as<vec_t>(), shift);
- for (auto i = 0; i < output.size(); ++i) {
+ for (std::size_t i = 0; i < output.size(); ++i) {
auto ref = (int64_t(input1[i]) * input2[i]) >> shift;
auto ref_sat = Type_(std::min<int64_t>(std::numeric_limits<Type_>::max(), std::max<int64_t>(std::numeric_limits<Type_>::min(), ref)));
CHECK(output[i] == ref_sat);
diff --git a/test/kernels/multiply_test.cc b/test/kernels/multiply_test.cc
index 90607f5..0eea965 100644
--- a/test/kernels/multiply_test.cc
+++ b/test/kernels/multiply_test.cc
@@ -22,7 +22,7 @@ void kernel_multiply_test() {
std::iota(input2.begin(), input2.end(), -int(VECTOR_LENGTH / 3));
*output.template as<vec_t>() = kernels::multiply<Type_>(*input1.template as<vec_t>(), *input2.template as<vec_t>());
- for (auto i = 0; i < output.size(); ++i)
+ for (std::size_t i = 0; i < output.size(); ++i)
CHECK(output[i] == Type_(input1[i] * input2[i]));
}
diff --git a/test/kernels/quantize_test.cc b/test/kernels/quantize_test.cc
index e666654..c9eae0a 100644
--- a/test/kernels/quantize_test.cc
+++ b/test/kernels/quantize_test.cc
@@ -21,7 +21,7 @@ void kernel_quantize_test() {
auto quant_mult = set1_ps<input_vec_t>(2.f);
*output.template as<output_vec_t>() = kernels::quantize(*input.template as<input_vec_t>(), quant_mult);
- for (auto i = 0; i < output.size(); ++i)
+ for (std::size_t i = 0; i < output.size(); ++i)
CHECK(output[i] == int(i*2.f));
}
diff --git a/test/kernels/relu_test.cc b/test/kernels/relu_test.cc
index fdf7c0e..25a212a 100644
--- a/test/kernels/relu_test.cc
+++ b/test/kernels/relu_test.cc
@@ -20,7 +20,7 @@ void kernel_relu_test() {
std::iota(input.begin(), input.end(), -int(VECTOR_LENGTH / 2));
*output.template as<vec_t>() = kernels::relu<ElemType_>(*input.template as<vec_t>());
- for (auto i = 0; i < output.size(); ++i)
+ for (std::size_t i = 0; i < output.size(); ++i)
CHECK(output[i] == (input[i] < 0 ? 0 : input[i]));
}
diff --git a/test/kernels/rescale_test.cc b/test/kernels/rescale_test.cc
index 1d7f556..d380c8d 100644
--- a/test/kernels/rescale_test.cc
+++ b/test/kernels/rescale_test.cc
@@ -22,7 +22,7 @@ void kernel_rescale_test() {
float scale = 2;
*output.template as<vi>() = kernels::rescale(*input.template as<vi>(), intgemm::set1_ps<vf>(scale));
- for (int i = 0; i < output.size(); ++i)
+ for (std::size_t i = 0; i < output.size(); ++i)
CHECK(output[i] == std::round(input[i] * scale));
}
diff --git a/test/kernels/sigmoid_test.cc b/test/kernels/sigmoid_test.cc
index e0e008e..e4743e2 100644
--- a/test/kernels/sigmoid_test.cc
+++ b/test/kernels/sigmoid_test.cc
@@ -27,7 +27,7 @@ void kernel_sigmoid_test() {
std::iota(input.begin(), input.end(), -int(VECTOR_LENGTH / 2));
*output.template as<vec_t>() = kernels::sigmoid(*input.template as<vec_t>());
- for (auto i = 0; i < output.size(); ++i)
+ for (std::size_t i = 0; i < output.size(); ++i)
CHECK_EPS(output[i], sigmoid_ref(input[i]), 0.001f);
}
diff --git a/test/kernels/tanh_test.cc b/test/kernels/tanh_test.cc
index 7407a11..737ac9b 100644
--- a/test/kernels/tanh_test.cc
+++ b/test/kernels/tanh_test.cc
@@ -20,7 +20,7 @@ void kernel_tanh_test() {
std::generate(input.begin(), input.end(), [] () { static int n = -int(VECTOR_LENGTH / 2); return n++ / float(VECTOR_LENGTH / 2); });
*output.template as<vec_t>() = kernels::tanh(*input.template as<vec_t>());
- for (auto i = 0; i < output.size(); ++i)
+ for (std::size_t i = 0; i < output.size(); ++i)
CHECK_EPS(output[i], tanh(input[i]), 0.001f);
}
diff --git a/test/kernels/unquantize_test.cc b/test/kernels/unquantize_test.cc
index 439970e..6f40da6 100644
--- a/test/kernels/unquantize_test.cc
+++ b/test/kernels/unquantize_test.cc
@@ -21,7 +21,7 @@ void kernel_unquantize_test() {
auto unquant_mult = set1_ps<output_vec_t>(0.5f);
*output.template as<output_vec_t>() = kernels::unquantize(*input.template as<input_vec_t>(), unquant_mult);
- for (auto i = 0; i < output.size(); ++i)
+ for (std::size_t i = 0; i < output.size(); ++i)
CHECK(output[i] == i * 0.5f);
}
diff --git a/test/kernels/upcast_test.cc b/test/kernels/upcast_test.cc
index 5c13dfd..df6a62e 100644
--- a/test/kernels/upcast_test.cc
+++ b/test/kernels/upcast_test.cc
@@ -23,7 +23,7 @@ void kernel_upcast8to16_test() {
output.template as<vi>()[0] = result.first;
output.template as<vi>()[1] = result.second;
- for (int i = 0; i < output.size(); ++i)
+ for (std::size_t i = 0; i < output.size(); ++i)
CHECK(output[i] == int16_t(input[i]));
}
@@ -55,7 +55,7 @@ void kernel_upcast16to32_test() {
output.template as<vi>()[0] = result.first;
output.template as<vi>()[1] = result.second;
- for (int i = 0; i < output.size(); ++i)
+ for (std::size_t i = 0; i < output.size(); ++i)
CHECK(output[i] == int32_t(input[i]));
}
@@ -89,7 +89,7 @@ void kernel_upcast8to32_test() {
output.template as<vi>()[2] = result.third;
output.template as<vi>()[3] = result.fourth;
- for (int i = 0; i < output.size(); ++i)
+ for (std::size_t i = 0; i < output.size(); ++i)
CHECK(output[i] == int32_t(input[i]));
}
diff --git a/test/kernels/write_test.cc b/test/kernels/write_test.cc
index 53a0ea6..aeaafcb 100644
--- a/test/kernels/write_test.cc
+++ b/test/kernels/write_test.cc
@@ -20,7 +20,7 @@ void kernel_write_test() {
std::iota(input.begin(), input.end(), 0);
kernels::write(*input.template as<vec_t>(), output.begin(), 0);
- for (auto i = 0; i < VECTOR_LENGTH; ++i)
+ for (std::size_t i = 0; i < VECTOR_LENGTH; ++i)
CHECK(output[i] == ElemType_(i));
}
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 97f68a3..7316bb0 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -31,7 +31,7 @@ INTGEMM_SSE2 TEST_CASE("Transpose 16", "[transpose]") {
__m128i *t = input.as<__m128i>();
Transpose16InLane(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7]);
- for (int16_t i = 0; i < input.size(); ++i) {
+ for (std::size_t i = 0; i < input.size(); ++i) {
CHECK_MESSAGE(ref[i] == input[i], "16-bit transpose failure at: " << i << ": " << ref[i] << " != " << input[i]);
}
}
@@ -49,7 +49,7 @@ INTGEMM_SSSE3 TEST_CASE("Transpose 8", "[transpose]") {
__m128i *t = input.as<__m128i>();
Transpose8InLane(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7], t[8], t[9], t[10], t[11], t[12], t[13], t[14], t[15]);
- for (int i = 0; i < input.size(); ++i) {
+ for (std::size_t i = 0; i < input.size(); ++i) {
CHECK_MESSAGE(ref[i] == input[i], "8-bit transpose failure at " << i << ": " << (int16_t)ref[i] << " != " << (int16_t)input[i]);
}
}
@@ -121,7 +121,7 @@ template <class Routine> void TestSelectColumnsB(Index rows = 64, Index cols = 1
AlignedVector<Integer> prepared(input.size());
Routine::PrepareB(input.begin(), prepared.begin(), 1, rows, cols);
- int kSelectCols = 24;
+ const int kSelectCols = 24;
Index select_cols[kSelectCols];
std::uniform_int_distribution<Index> col_dist(0, cols - 1);
for (auto& it : select_cols) {
@@ -197,7 +197,7 @@ template <float (*Backend) (const float *, const float *)> void TestMaxAbsolute(
const std::size_t kLengthMax = 65;
AlignedVector<float> test(kLengthMax);
for (std::size_t len = 1; len < kLengthMax; ++len) {
- for (int t = 0; t < len; ++t) {
+ for (std::size_t t = 0; t < len; ++t) {
// Fill with [-8, 8).
for (auto& it : test) {
it = dist(gen);
@@ -268,7 +268,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
for (auto& it : B) {
it = dist(gen);
}
-
+
float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
float unquant_mult = 1.0/(quant_mult*quant_mult);
@@ -288,12 +288,12 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size());
AlignedVector<float> slowint_C(test_C.size());
// Assuming A is just quantization here.
- references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+ references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo&) {
return sum * unquant_mult;
});
AlignedVector<float> float_C(test_C.size());
- references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+ references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo&) {
return sum;
});
diff --git a/test/test.h b/test/test.h
index af6e17a..a1ca724 100644
--- a/test/test.h
+++ b/test/test.h
@@ -106,7 +106,7 @@ void Multiply(const TypeA* A, const TypeB* B, float* C, Index A_rows, Index widt
// Matrix rearragement
template <typename Type>
-void Rearragement(const Type* input, Type* output, int simd, int unroll, Index rows, Index cols) {
+void Rearragement(const Type* input, Type* output, Index simd, Index unroll, Index rows, Index cols) {
for (Index c = 0; c < cols; c += unroll) {
for (Index r = 0; r < rows; r += simd) {
for (Index i = 0; i < unroll; ++i)