Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Marc Valin <jmvalin@amazon.com>2023-07-31 10:03:37 +0300
committerJean-Marc Valin <jmvalin@amazon.com>2023-08-02 02:16:27 +0300
commite9f8402a7122ca03e894d161c50706053bf4fb83 (patch)
treed33df933025cbccb577bb53545cf6fea554a289f
parent5eaa4a504f865e73c0e480fb95113e67f9310ffa (diff)
Handle float matrices with multiple of 8 rows
-rw-r--r--dnn/nnet.c2
-rw-r--r--dnn/vec.h40
-rw-r--r--dnn/vec_avx.h36
-rw-r--r--dnn/vec_neon.h49
4 files changed, 126 insertions, 1 deletion
diff --git a/dnn/nnet.c b/dnn/nnet.c
index 1c0035d0..05b0ea90 100644
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -78,7 +78,7 @@ void compute_linear(const LinearLayer *linear, float *out, const float *in)
N = linear->nb_outputs;
if (linear->float_weights != NULL) {
if (linear->weights_idx != NULL) sparse_sgemv8x4(out, linear->float_weights, linear->weights_idx, N, in);
- else sgemv16x1(out, linear->float_weights, N, M, N, in);
+ else sgemv(out, linear->float_weights, N, M, N, in);
} else if (linear->weights != NULL) {
if (linear->weights_idx != NULL) sparse_cgemv8x4(out, linear->weights, linear->weights_idx, linear->scale, N, M, in);
else cgemv8x4(out, linear->weights, linear->scale, N, M, in);
diff --git a/dnn/vec.h b/dnn/vec.h
index f6085cee..5b6951bb 100644
--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -92,6 +92,46 @@ static inline void sgemv16x1(float *out, const float *weights, int rows, int col
}
}
+static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+ int i, j;
+ OPUS_CLEAR(out, rows);
+ for (i=0;i<rows;i+=8)
+ {
+ for (j=0;j<cols;j++)
+ {
+ const float * restrict w;
+ float * restrict y;
+ float xj;
+ w = &weights[j*col_stride + i];
+ xj = x[j];
+ y = &out[i];
+ y[0] += w[0]*xj;
+ y[1] += w[1]*xj;
+ y[2] += w[2]*xj;
+ y[3] += w[3]*xj;
+ y[4] += w[4]*xj;
+ y[5] += w[5]*xj;
+ y[6] += w[6]*xj;
+ y[7] += w[7]*xj;
+ }
+ }
+}
+
+static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+ if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
+ else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
+ else {
+ int i, j;
+ for (i=0;i<rows;i++)
+ {
+ out[i] = 0;
+ for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
+ }
+ }
+}
+
static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x)
{
int i, j;
diff --git a/dnn/vec_avx.h b/dnn/vec_avx.h
index 77b3a0e0..4747bb41 100644
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -701,6 +701,42 @@ static inline void sgemv16x1(float *out, const float *weights, int rows, int col
}
}
+static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+ int i, j;
+ for (i=0;i<rows;i+=8)
+ {
+ float *y;
+ __m256 vy0;
+ y = &out[i];
+ vy0 = _mm256_setzero_ps();
+ for (j=0;j<cols;j++)
+ {
+ __m256 vxj;
+ __m256 vw;
+ vxj = _mm256_broadcast_ss(&x[j]);
+
+ vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
+ vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
+ }
+ _mm256_storeu_ps (&y[0], vy0);
+ }
+}
+
+static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+ if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
+ else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
+ else {
+ int i, j;
+ for (i=0;i<rows;i++)
+ {
+ out[i] = 0;
+ for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
+ }
+ }
+}
+
static inline void sparse_sgemv8x4(float *out, const float *weights, const int *idx, int rows, const float *x)
{
int i, j;
diff --git a/dnn/vec_neon.h b/dnn/vec_neon.h
index 38c20d7b..48e3eaa1 100644
--- a/dnn/vec_neon.h
+++ b/dnn/vec_neon.h
@@ -239,6 +239,55 @@ static inline void sgemv16x1(float *out, const float *weights, int rows, int col
}
}
+static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+ int i, j;
+ for (i=0;i<rows;i+=8)
+ {
+ float * restrict y = &out[i];
+
+ /* keep y[0..15] in registers for duration of inner loop */
+
+ float32x4_t y0_3 = vdupq_n_f32(0);
+ float32x4_t y4_7 = vdupq_n_f32(0);
+
+ for (j=0;j<cols;j++)
+ {
+ const float * restrict w;
+ float32x4_t wvec0_3, wvec4_7;
+ float32x4_t xj;
+
+ w = &weights[j*col_stride + i];
+ wvec0_3 = vld1q_f32(&w[0]);
+ wvec4_7 = vld1q_f32(&w[4]);
+
+ xj = vld1q_dup_f32(&x[j]);
+
+ y0_3 = vmlaq_f32(y0_3, wvec0_3, xj);
+ y4_7 = vmlaq_f32(y4_7, wvec4_7, xj);
+ }
+
+ /* save y[0..15] back to memory */
+
+ vst1q_f32(&y[0], y0_3);
+ vst1q_f32(&y[4], y4_7);
+ }
+}
+
+static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+ if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
+ else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
+ else {
+ int i, j;
+ for (i=0;i<rows;i++)
+ {
+ out[i] = 0;
+ for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
+ }
+ }
+}
+
/* Temporarily use unoptimized version */
static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x)
{