From 7a74ff1c51207380f8bad35a44a7544bf296bdf9 Mon Sep 17 00:00:00 2001 From: Mateusz Chudyk Date: Mon, 20 Apr 2020 18:14:17 +0100 Subject: Fix XXXCustomTile functions --- intgemm.h | 84 ++++++++++++++++++++++++++++++++++++--------------------------- 1 file changed, 48 insertions(+), 36 deletions(-) diff --git a/intgemm.h b/intgemm.h index 3f4e8c2..3639721 100644 --- a/intgemm.h +++ b/intgemm.h @@ -256,12 +256,14 @@ struct Int8 { // Warning: the output of PrepareB depends on the CPU. // It will match the Multiply function on the same CPU though. static void PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) { - switch (cols % 32) { - case 0: PrepareBCustomTile<4>(input, output, quant_mult, rows, cols); break; - case 24: PrepareBCustomTile<3>(input, output, quant_mult, rows, cols); break; - case 16: PrepareBCustomTile<2>(input, output, quant_mult, rows, cols); break; - default: PrepareBCustomTile<1>(input, output, quant_mult, rows, cols); break; - } + if (cols % 32 == 0) + PrepareBCustomTile<4>(input, output, quant_mult, rows, cols); + else if (cols % 24 == 0) + PrepareBCustomTile<3>(input, output, quant_mult, rows, cols); + else if (cols % 16 == 0) + PrepareBCustomTile<2>(input, output, quant_mult, rows, cols); + else + PrepareBCustomTile<1>(input, output, quant_mult, rows, cols); } template @@ -286,12 +288,14 @@ struct Int8 { // Multiply C = A * B, presuming A and B have been prepared. template static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { - switch (B_cols % 32) { - case 0: MultiplyCustomTile<1, 4, Callback>(A, B, A_rows, width, B_cols, callback); break; - case 24: MultiplyCustomTile<1, 3, Callback>(A, B, A_rows, width, B_cols, callback); break; - case 16: MultiplyCustomTile<1, 2, Callback>(A, B, A_rows, width, B_cols, callback); break; - default: MultiplyCustomTile<1, 1, Callback>(A, B, A_rows, width, B_cols, callback); break; - } + if (B_cols % 32 == 0) + MultiplyCustomTile<1, 4, Callback>(A, B, A_rows, width, B_cols, callback); + else if (B_cols % 24 == 0) + MultiplyCustomTile<1, 3, Callback>(A, B, A_rows, width, B_cols, callback); + else if (B_cols % 16 == 0) + MultiplyCustomTile<1, 2, Callback>(A, B, A_rows, width, B_cols, callback); + else + MultiplyCustomTile<1, 1, Callback>(A, B, A_rows, width, B_cols, callback); } template @@ -352,12 +356,14 @@ struct Int8Shift { // Warning: the output of PrepareB depends on the CPU. // It will match the Multiply function on the same CPU though. static void PrepareB(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) { - switch (cols % 32) { - case 0: PrepareBCustomTile<4>(input, output, quant_mult, rows, cols); break; - case 24: PrepareBCustomTile<3>(input, output, quant_mult, rows, cols); break; - case 16: PrepareBCustomTile<2>(input, output, quant_mult, rows, cols); break; - default: PrepareBCustomTile<1>(input, output, quant_mult, rows, cols); break; - } + if (cols % 32 == 0) + PrepareBCustomTile<4>(input, output, quant_mult, rows, cols); + else if (cols % 24 == 0) + PrepareBCustomTile<3>(input, output, quant_mult, rows, cols); + else if (cols % 16 == 0) + PrepareBCustomTile<2>(input, output, quant_mult, rows, cols); + else + PrepareBCustomTile<1>(input, output, quant_mult, rows, cols); } template @@ -374,12 +380,14 @@ struct Int8Shift { // Multiply C = A * B + Bias, presuming A, B and Bias have all been prepared (for A, PrepareAnew should be used template static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { - switch (B_cols % 32) { - case 0: MultiplyCustomTile<1, 4, Callback>(A, B, A_rows, width, B_cols, callback); break; - case 24: MultiplyCustomTile<1, 3, Callback>(A, B, A_rows, width, B_cols, callback); break; - case 16: MultiplyCustomTile<1, 2, Callback>(A, B, A_rows, width, B_cols, callback); break; - default: MultiplyCustomTile<1, 1, Callback>(A, B, A_rows, width, B_cols, callback); break; - } + if (B_cols % 32 == 0) + MultiplyCustomTile<1, 4, Callback>(A, B, A_rows, width, B_cols, callback); + else if (B_cols % 24 == 0) + MultiplyCustomTile<1, 3, Callback>(A, B, A_rows, width, B_cols, callback); + else if (B_cols % 16 == 0) + MultiplyCustomTile<1, 2, Callback>(A, B, A_rows, width, B_cols, callback); + else + MultiplyCustomTile<1, 1, Callback>(A, B, A_rows, width, B_cols, callback); } template @@ -446,12 +454,14 @@ struct Int16 { // Warning: the output of PrepareB depends on the CPU. // It will match the Multiply function on the same CPU though. static void PrepareB(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) { - switch (cols % 32) { - case 0: PrepareBCustomTile<4>(input, output, quant_mult, rows, cols); break; - case 24: PrepareBCustomTile<3>(input, output, quant_mult, rows, cols); break; - case 16: PrepareBCustomTile<2>(input, output, quant_mult, rows, cols); break; - default: PrepareBCustomTile<1>(input, output, quant_mult, rows, cols); break; - } + if (cols % 32 == 0) + PrepareBCustomTile<4>(input, output, quant_mult, rows, cols); + else if (cols % 24 == 0) + PrepareBCustomTile<3>(input, output, quant_mult, rows, cols); + else if (cols % 16 == 0) + PrepareBCustomTile<2>(input, output, quant_mult, rows, cols); + else + PrepareBCustomTile<1>(input, output, quant_mult, rows, cols); } template @@ -476,12 +486,14 @@ struct Int16 { // Multiply C = A * B, presuming A and B have been prepared. template static void Multiply(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { - switch (B_cols % 32) { - case 0: MultiplyCustomTile<1, 4, Callback>(A, B, A_rows, width, B_cols, callback); break; - case 24: MultiplyCustomTile<1, 3, Callback>(A, B, A_rows, width, B_cols, callback); break; - case 16: MultiplyCustomTile<1, 2, Callback>(A, B, A_rows, width, B_cols, callback); break; - default: MultiplyCustomTile<1, 1, Callback>(A, B, A_rows, width, B_cols, callback); break; - } + if (B_cols % 32 == 0) + MultiplyCustomTile<1, 4, Callback>(A, B, A_rows, width, B_cols, callback); + else if (B_cols % 24 == 0) + MultiplyCustomTile<1, 3, Callback>(A, B, A_rows, width, B_cols, callback); + else if (B_cols % 16 == 0) + MultiplyCustomTile<1, 2, Callback>(A, B, A_rows, width, B_cols, callback); + else + MultiplyCustomTile<1, 1, Callback>(A, B, A_rows, width, B_cols, callback); } template -- cgit v1.2.3