
gitlab.xiph.org/xiph/opus.git
author    Jean-Marc Valin <jmvalin@amazon.com>  2023-10-30 07:08:53 +0300
committer Jean-Marc Valin <jmvalin@amazon.com>  2023-10-30 07:08:53 +0300
commit    62b546436fc07035802eb998f61702ee2716db60 (patch)
tree      5945a43e23957c417419c90f98759b76ec1a0282
parent    61fb3b16894c8fff523efb4255247d151ed5bad5 (diff)
Speed up general case for float matrix multiply
-rw-r--r--  dnn/vec_avx.h | 105
1 file changed, 46 insertions(+), 59 deletions(-)
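
The change below folds the two specialized kernels, sgemv16x1 and sgemv8x1, into sgemv itself. The old code only used AVX when rows was a multiple of 16 or of 8 and otherwise fell back to a fully scalar loop; the new code cascades a 16-wide pass, an 8-wide pass, and a scalar tail, so any row count gets vectorized as far as possible.

For reference, here is a minimal scalar sketch of the computation sgemv performs (sgemv_ref, main, and the test values are illustrative, not part of the patch); it mirrors the scalar fallback loop kept at the end of the new sgemv:

#include <stdio.h>

/* Scalar reference for sgemv: out = weights * x, where the rows x cols
   matrix is stored column by column and column j starts at
   weights[j*col_stride] (so col_stride >= rows). */
static void sgemv_ref(float *out, const float *weights, int rows, int cols,
                      int col_stride, const float *x)
{
    int i, j;
    for (i=0;i<rows;i++)
    {
        out[i] = 0;
        for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
    }
}

int main(void)
{
    /* rows = 20 is the "general case": a multiple of neither 16 nor 8. */
    float weights[3*20], x[3] = {1.0f, 2.0f, 3.0f}, out[20];
    int i, j;
    for (j=0;j<3;j++)
        for (i=0;i<20;i++)
            weights[j*20 + i] = (float)(i + 1); /* col_stride == rows here */
    sgemv_ref(out, weights, 20, 3, 20, x);
    printf("out[0]=%g out[19]=%g\n", out[0], out[19]); /* prints 6 and 120 */
    return 0;
}
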
diff --git a/dnn/vec_avx.h b/dnn/vec_avx.h
index b41f9862..767d7e19 100644
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -666,67 +666,54 @@ static inline mm256i_emu opus_mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a,
#error "No optimizations in vec_avx.h. This should never happen. "
#endif
-static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
-{
- int i, j;
- for (i=0;i<rows;i+=16)
- {
- float *y;
- __m256 vy0, vy8;
- y = &out[i];
- vy0 = _mm256_setzero_ps();
- vy8 = _mm256_setzero_ps();
- for (j=0;j<cols;j++)
- {
- __m256 vxj;
- __m256 vw;
- vxj = _mm256_broadcast_ss(&x[j]);
-
- vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
- vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
-
- vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
- vy8 = _mm256_fmadd_ps(vw, vxj, vy8);
- }
- _mm256_storeu_ps (&y[0], vy0);
- _mm256_storeu_ps (&y[8], vy8);
- }
-}
-
-static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
-{
- int i, j;
- for (i=0;i<rows;i+=8)
- {
- float *y;
- __m256 vy0;
- y = &out[i];
- vy0 = _mm256_setzero_ps();
- for (j=0;j<cols;j++)
- {
- __m256 vxj;
- __m256 vw;
- vxj = _mm256_broadcast_ss(&x[j]);
-
- vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
- vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
- }
- _mm256_storeu_ps (&y[0], vy0);
- }
-}
-
static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
- if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
- else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
- else {
- int i, j;
- for (i=0;i<rows;i++)
- {
- out[i] = 0;
- for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
- }
- }
+ int i, j;
+ i=0;
+ for (;i<rows-15;i+=16)
+ {
+ float *y;
+ __m256 vy0, vy8;
+ y = &out[i];
+ vy0 = _mm256_setzero_ps();
+ vy8 = _mm256_setzero_ps();
+ for (j=0;j<cols;j++)
+ {
+ __m256 vxj;
+ __m256 vw;
+ vxj = _mm256_broadcast_ss(&x[j]);
+
+ vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
+ vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
+
+ vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
+ vy8 = _mm256_fmadd_ps(vw, vxj, vy8);
+ }
+ _mm256_storeu_ps (&y[0], vy0);
+ _mm256_storeu_ps (&y[8], vy8);
+ }
+ for (;i<rows-7;i+=8)
+ {
+ float *y;
+ __m256 vy0;
+ y = &out[i];
+ vy0 = _mm256_setzero_ps();
+ for (j=0;j<cols;j++)
+ {
+ __m256 vxj;
+ __m256 vw;
+ vxj = _mm256_broadcast_ss(&x[j]);
+
+ vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
+ vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
+ }
+ _mm256_storeu_ps (&y[0], vy0);
+ }
+ for (;i<rows;i++)
+ {
+ out[i] = 0;
+ for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
+ }
}
static inline void sparse_sgemv8x4(float *out, const float *weights, const int *idx, int rows, const float *x)
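
With the cascaded loops, a row count such as 20, which the old dispatch sent entirely to the scalar fallback because it is a multiple of neither 16 nor 8, now gets one 16-wide pass for rows 0..15 and scalar code only for the last 4 rows. Even a multiple of 8 such as 24 improves, running 16-wide over its first 16 rows instead of 8-wide throughout, while row counts that are multiples of 16 compile to the same inner loop as before. That is what makes this a speedup for the general case without regressing the aligned cases.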