diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2023-08-02 02:14:29 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-08-02 02:14:29 +0300 |
commit | 0a61ee08b0b8c0e820ef6966c23b3eea58abbfaa (patch) | |
tree | 45f589f35000674b21efe316c52fecda371afa81 | |
parent | eaee3d9d4e0276fcc1b722cc629f451eabdbb3ac (diff) |
Bring back support for non-multiples of 8 (exp-fwgan8)
-rw-r--r-- | dnn/vec.h | 11 | ||||
-rw-r--r-- | dnn/vec_avx.h | 11 | ||||
-rw-r--r-- | dnn/vec_neon.h | 11 |
3 files changed, 27 insertions, 6 deletions
@@ -120,9 +120,16 @@ static inline void sgemv16x1(float *out, const float *weights, int rows, int col
 
 static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
 {
-   celt_assert((rows&7) == 0);
    if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
-   else sgemv8x1(out, weights, rows, cols, col_stride, x);
+   else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
+   else {
+      int i, j;
+      for (i=0;i<rows;i++)
+      {
+         out[i] = 0;
+         for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
+      }
+   }
 }
 
 static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x)
diff --git a/dnn/vec_avx.h b/dnn/vec_avx.h
index b31caa88..4747bb41 100644
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -725,9 +725,16 @@ static inline void sgemv8x1(float *out, const float *weights, int rows, int cols
 
 static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
 {
-   celt_assert((rows&7) == 0);
    if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
-   else sgemv8x1(out, weights, rows, cols, col_stride, x);
+   else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
+   else {
+      int i, j;
+      for (i=0;i<rows;i++)
+      {
+         out[i] = 0;
+         for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
+      }
+   }
 }
 
 static inline void sparse_sgemv8x4(float *out, const float *weights, const int *idx, int rows, const float *x)
diff --git a/dnn/vec_neon.h b/dnn/vec_neon.h
index af3b204f..48e3eaa1 100644
--- a/dnn/vec_neon.h
+++ b/dnn/vec_neon.h
@@ -276,9 +276,16 @@ static inline void sgemv8x1(float *out, const float *weights, int rows, int cols
 
 static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
 {
-   celt_assert((rows&7) == 0);
    if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
-   else sgemv8x1(out, weights, rows, cols, col_stride, x);
+   else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
+   else {
+      int i, j;
+      for (i=0;i<rows;i++)
+      {
+         out[i] = 0;
+         for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
+      }
+   }
 }
 
 /* Temporarily use unoptimized version */