diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-10-30 07:08:53 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-10-30 07:08:53 +0300 |
commit | 62b546436fc07035802eb998f61702ee2716db60 (patch) | |
tree | 5945a43e23957c417419c90f98759b76ec1a0282 | |
parent | 61fb3b16894c8fff523efb4255247d151ed5bad5 (diff) |
Speed up general case for float matrix multiply
-rw-r--r-- | dnn/vec_avx.h | 105 |
1 file changed, 46 insertions, 59 deletions
diff --git a/dnn/vec_avx.h b/dnn/vec_avx.h index b41f9862..767d7e19 100644 --- a/dnn/vec_avx.h +++ b/dnn/vec_avx.h @@ -666,67 +666,54 @@ static inline mm256i_emu opus_mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, #error "No optimizations in vec_avx.h. This should never happen. " #endif -static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x) -{ - int i, j; - for (i=0;i<rows;i+=16) - { - float *y; - __m256 vy0, vy8; - y = &out[i]; - vy0 = _mm256_setzero_ps(); - vy8 = _mm256_setzero_ps(); - for (j=0;j<cols;j++) - { - __m256 vxj; - __m256 vw; - vxj = _mm256_broadcast_ss(&x[j]); - - vw = _mm256_loadu_ps(&weights[j*col_stride + i]); - vy0 = _mm256_fmadd_ps(vw, vxj, vy0); - - vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]); - vy8 = _mm256_fmadd_ps(vw, vxj, vy8); - } - _mm256_storeu_ps (&y[0], vy0); - _mm256_storeu_ps (&y[8], vy8); - } -} - -static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x) -{ - int i, j; - for (i=0;i<rows;i+=8) - { - float *y; - __m256 vy0; - y = &out[i]; - vy0 = _mm256_setzero_ps(); - for (j=0;j<cols;j++) - { - __m256 vxj; - __m256 vw; - vxj = _mm256_broadcast_ss(&x[j]); - - vw = _mm256_loadu_ps(&weights[j*col_stride + i]); - vy0 = _mm256_fmadd_ps(vw, vxj, vy0); - } - _mm256_storeu_ps (&y[0], vy0); - } -} - static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x) { - if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x); - else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x); - else { - int i, j; - for (i=0;i<rows;i++) - { - out[i] = 0; - for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j]; - } - } + int i, j; + i=0; + for (;i<rows-15;i+=16) + { + float *y; + __m256 vy0, vy8; + y = &out[i]; + vy0 = _mm256_setzero_ps(); + vy8 = _mm256_setzero_ps(); + for (j=0;j<cols;j++) + { + __m256 vxj; + __m256 vw; + vxj = 
_mm256_broadcast_ss(&x[j]); + + vw = _mm256_loadu_ps(&weights[j*col_stride + i]); + vy0 = _mm256_fmadd_ps(vw, vxj, vy0); + + vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]); + vy8 = _mm256_fmadd_ps(vw, vxj, vy8); + } + _mm256_storeu_ps (&y[0], vy0); + _mm256_storeu_ps (&y[8], vy8); + } + for (;i<rows-7;i+=8) + { + float *y; + __m256 vy0; + y = &out[i]; + vy0 = _mm256_setzero_ps(); + for (j=0;j<cols;j++) + { + __m256 vxj; + __m256 vw; + vxj = _mm256_broadcast_ss(&x[j]); + + vw = _mm256_loadu_ps(&weights[j*col_stride + i]); + vy0 = _mm256_fmadd_ps(vw, vxj, vy0); + } + _mm256_storeu_ps (&y[0], vy0); + } + for (;i<rows;i++) + { + out[i] = 0; + for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j]; + } } static inline void sparse_sgemv8x4(float *out, const float *weights, const int *idx, int rows, const float *x) |