| author    | Jean-Marc Valin <jmvalin@amazon.com> | 2023-07-16 09:21:49 +0300 |
|-----------|--------------------------------------|---------------------------|
| committer | Jean-Marc Valin <jmvalin@amazon.com> | 2023-07-20 08:01:32 +0300 |
| commit    | f5a68a41b01d7edcd629ec9b284302df1518e922 (patch) | |
| tree      | ee1c7c20caa5d4a4a81bbd534827dcb1b25e8597 | |
| parent    | 8423ef1de25d8006be837504c22080e0319a1d26 (diff) | |
Add generic linear layer
Should be able to handle all previous GRU variants and more.
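For reference, the recurrence implemented by the new `compute_generic_gru()` below (read directly from the code: biases and the optional diagonal term are folded into the two `LinearLayer` transforms, and the reset gate is applied to the recurrent product rather than to the state, matching the existing GRU code):

$$
\begin{aligned}
z_t &= \sigma(W_z x_t + U_z h_{t-1}) \\
r_t &= \sigma(W_r x_t + U_r h_{t-1}) \\
\tilde{h}_t &= \tanh(W_h x_t + r_t \odot (U_h h_{t-1})) \\
h_t &= z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t
\end{aligned}
$$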
| mode       | file           | lines changed |
|------------|----------------|---------------|
| -rw-r--r-- | dnn/nnet.c     | 213 |
| -rw-r--r-- | dnn/nnet.h     | 12  |
| -rw-r--r-- | dnn/vec.h      | 224 |
| -rw-r--r-- | dnn/vec_avx.h  | 204 |
| -rw-r--r-- | dnn/vec_neon.h | 73  |
5 files changed, 722 insertions, 4 deletions
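Before the diff itself, a minimal sketch of the intended call pattern: the disabled `#else` paths below show that every existing layer reduces to filling a `LinearLayer` and calling `compute_linear()`. For a plain float dense layer it looks like the following (`dense_via_linear` is a hypothetical wrapper name; types and functions are as declared in `dnn/nnet.h` by this patch):

```c
#include "nnet.h"  /* LinearLayer, compute_linear(), compute_activation() */

static void dense_via_linear(const DenseLayer *layer, float *output, const float *input)
{
   LinearLayer matrix;
   /* Leave the unused weight kinds NULL so compute_linear() takes the
      dense float path (sgemv16x1) rather than the sparse or int8 paths. */
   matrix.bias          = layer->bias;          /* added after the matrix product */
   matrix.subias        = NULL;                 /* int8/SU-arch bias only */
   matrix.float_weights = layer->input_weights; /* selects the float branch */
   matrix.weights       = NULL;                 /* no int8 weights */
   matrix.weights_idx   = NULL;                 /* NULL = dense, non-NULL = sparse blocks */
   matrix.diag          = NULL;                 /* GRU recurrent-weights-only diagonal term */
   matrix.scale         = NULL;                 /* int8 dequantization scale only */
   matrix.nb_inputs     = layer->nb_inputs;
   matrix.nb_outputs    = layer->nb_neurons;
   compute_linear(&matrix, output, input);
   compute_activation(output, output, layer->nb_neurons, layer->activation);
}
```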
diff --git a/dnn/nnet.c b/dnn/nnet.c
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -85,6 +85,73 @@ static void sgemv_accum(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
    }
 }
 
+void compute_linear(const LinearLayer *linear, float *out, const float *in)
+{
+   int i, M, N;
+   const float *bias;
+   bias = linear->bias;
+   M = linear->nb_inputs;
+   N = linear->nb_outputs;
+   if (linear->float_weights != NULL) {
+      if (linear->weights_idx != NULL) sparse_sgemv8x4(out, linear->float_weights, linear->weights_idx, N, in);
+      else sgemv16x1(out, linear->float_weights, N, M, N, in);
+   } else if (linear->weights != NULL) {
+      if (linear->weights_idx != NULL) sparse_cgemv8x4(out, linear->weights, linear->weights_idx, linear->scale, N, M, in);
+      else cgemv8x4(out, linear->weights, linear->scale, N, M, in);
+      /* Only use SU biases for integer matrices on SU archs. */
+#ifdef USE_SU_BIAS
+      bias = linear->subias;
+#endif
+   }
+   else OPUS_CLEAR(out, N);
+   if (bias != NULL) {
+      for (i=0;i<N;i++) out[i] += bias[i];
+   }
+   if (linear->diag) {
+      /* Diag is only used for GRU recurrent weights. */
+      celt_assert(3*M == N);
+      for (i=0;i<M;i++) {
+         out[i] += linear->diag[i]*in[i];
+         out[i+M] += linear->diag[i+M]*in[i];
+         out[i+2*M] += linear->diag[i+2*M]*in[i];
+      }
+   }
+}
+
+#define MAX_RNN_NEURONS_ALL IMAX(IMAX(MAX_RNN_NEURONS, PLC_MAX_RNN_NEURONS), DRED_MAX_RNN_NEURONS)
+
+
+void compute_generic_gru(const LinearLayer *input_weights, const LinearLayer *recurrent_weights, float *state, const float *in)
+{
+   int i;
+   int N;
+   float zrh[3*MAX_RNN_NEURONS_ALL];
+   float recur[3*MAX_RNN_NEURONS_ALL];
+   float *z;
+   float *r;
+   float *h;
+   celt_assert(3*recurrent_weights->nb_inputs == recurrent_weights->nb_outputs);
+   celt_assert(input_weights->nb_outputs == recurrent_weights->nb_outputs);
+   N = recurrent_weights->nb_inputs;
+   z = zrh;
+   r = &zrh[N];
+   h = &zrh[2*N];
+   celt_assert(recurrent_weights->nb_outputs <= 3*MAX_RNN_NEURONS_ALL);
+   celt_assert(in != state);
+   compute_linear(input_weights, zrh, in);
+   compute_linear(recurrent_weights, recur, state);
+   for (i=0;i<2*N;i++)
+      zrh[i] += recur[i];
+   compute_activation(zrh, zrh, 2*N, ACTIVATION_SIGMOID);
+   for (i=0;i<N;i++)
+      h[i] += recur[2*N+i]*r[i];
+   compute_activation(h, h, N, ACTIVATION_TANH);
+   for (i=0;i<N;i++)
+      h[i] = z[i]*state[i] + (1-z[i])*h[i];
+   for (i=0;i<N;i++)
+      state[i] = h[i];
+}
+
 void compute_activation(float *output, const float *input, int N, int activation)
 {
    int i;
@@ -119,6 +186,7 @@ void compute_activation(float *output, const float *input, int N, int activation)
    }
 }
 
+#if 1
 void _lpcnet_compute_dense(const DenseLayer *layer, float *output, const float *input)
 {
    int i;
@@ -133,7 +201,24 @@ void _lpcnet_compute_dense(const DenseLayer *layer, float *output, const float *input)
    sgemv_accum(output, layer->input_weights, N, M, stride, input);
    compute_activation(output, output, N, layer->activation);
 }
-
+#else
+void _lpcnet_compute_dense(const DenseLayer *layer, float *output, const float *input)
+{
+   LinearLayer matrix;
+   celt_assert(input != output);
+   matrix.bias = layer->bias;
+   matrix.subias = NULL;
+   matrix.float_weights = layer->input_weights;
+   matrix.weights = NULL;
+   matrix.weights_idx = NULL;
+   matrix.diag = NULL;
+   matrix.nb_inputs = layer->nb_inputs;
+   matrix.nb_outputs = layer->nb_neurons;
+   matrix.scale = NULL;
+   compute_linear(&matrix, output, input);
+   compute_activation(output, output, layer->nb_neurons, layer->activation);
+}
+#endif
 
 int sample_mdense(const MDenseLayer *layer, const float *input, const float *sampling_logit_table, kiss99_ctx *rng)
 {
@@ -188,9 +273,15 @@ int sample_mdense(const MDenseLayer *layer, const float *input, const float *sampling_logit_table, kiss99_ctx *rng)
 }
 
+#ifdef USE_SU_BIAS
+#define bias_type subias
+#else
+#define bias_type bias
+#endif
+#define MAX_IDX_SIZE 8192
 
-#define MAX_RNN_NEURONS_ALL IMAX(IMAX(MAX_RNN_NEURONS, PLC_MAX_RNN_NEURONS), DRED_MAX_RNN_NEURONS)
+#if 1
 void compute_gruB(const GRULayer *gru, const float* gru_b_condition, float *state, const float *input)
 {
    int i;
@@ -239,7 +330,59 @@ void compute_gruB(const GRULayer *gru, const float* gru_b_condition, float *state, const float *input)
       state[i] = h[i];
 }
 
+#else
+
+void compute_gruB(const GRULayer *gru, const float* gru_b_condition, float *state, const float *input)
+{
+   LinearLayer in_matrix, rec_matrix;
+   int i, M, N;
+   float bias[3*MAX_RNN_NEURONS_ALL];
+   float scale[3*MAX_RNN_NEURONS_ALL];
+   M = gru->nb_inputs;
+   N = gru->nb_neurons;
+
+   in_matrix.bias = bias;
+   in_matrix.diag = NULL;
+   in_matrix.nb_inputs = M;
+   in_matrix.nb_outputs = 3*N;
+   in_matrix.subias = bias;
+#ifdef DISABLE_DOT_PROD
+   for (i=0;i<3*N;i++) bias[i] = gru->bias[i] + gru_b_condition[i];
+   in_matrix.scale = NULL;
+   in_matrix.float_weights = gru->input_weights;
+   in_matrix.weights = NULL;
+#else
+   for (i=0;i<3*N;i++) bias[i] = gru->bias_type[i] + gru_b_condition[i];
+   for (i=0;i<3*N;i++) scale[i] = SCALE_1;
+   in_matrix.scale = scale;
+   in_matrix.weights = gru->input_weights;
+   in_matrix.float_weights = NULL;
+#endif
+   in_matrix.weights_idx = gru->input_weights_idx;
+
+   rec_matrix.bias = &gru->bias[3*N];
+   rec_matrix.diag = NULL;
+   rec_matrix.nb_inputs = N;
+   rec_matrix.nb_outputs = 3*N;
+   rec_matrix.scale = scale;
+   rec_matrix.subias = &gru->subias[3*N];
+#ifdef DISABLE_DOT_PROD
+   rec_matrix.scale = NULL;
+   rec_matrix.float_weights = gru->recurrent_weights;
+   rec_matrix.weights = NULL;
+#else
+   rec_matrix.scale = scale;
+   rec_matrix.weights = gru->recurrent_weights;
+   rec_matrix.float_weights = NULL;
+#endif
+   rec_matrix.weights_idx = NULL;
+   compute_generic_gru(&in_matrix, &rec_matrix, state, input);
+}
+#endif
+
+
+#if 1
 /* The input of this GRU is after the input matrix multiply. */
 void compute_sparse_gru(const SparseGRULayer *gru, float *state, const float *input)
 {
@@ -280,9 +423,49 @@ void compute_sparse_gru(const SparseGRULayer *gru, float *state, const float *input)
    for (i=0;i<N;i++)
       state[i] = z[i]*state[i] + (1-z[i])*h[i];
 }
+#else
+/* The input of this GRU is after the input matrix multiply. */
+void compute_sparse_gru(const SparseGRULayer *gru, float *state, const float *input)
+{
+   LinearLayer in_matrix, rec_matrix;
+   int i, N;
+   float scale[3*MAX_RNN_NEURONS_ALL];
+   N = gru->nb_neurons;
+
+   in_matrix.bias = input;
+   in_matrix.diag = NULL;
+   in_matrix.nb_inputs = N;
+   in_matrix.nb_outputs = 3*N;
+   in_matrix.subias = input;
+   in_matrix.scale = NULL;
+   in_matrix.float_weights = NULL;
+   in_matrix.weights = NULL;
+   in_matrix.weights_idx = NULL;
+
+   rec_matrix.bias = &gru->bias[3*N];
+   rec_matrix.diag = gru->diag_weights;
+   rec_matrix.nb_inputs = N;
+   rec_matrix.nb_outputs = 3*N;
+   rec_matrix.subias = &gru->subias[3*N];
+#ifdef DISABLE_DOT_PROD
+   rec_matrix.scale = NULL;
+   rec_matrix.float_weights = gru->recurrent_weights;
+   rec_matrix.weights = NULL;
+#else
+   for (i=0;i<3*N;i++) scale[i] = SCALE_1;
+   rec_matrix.scale = scale;
+   rec_matrix.weights = gru->recurrent_weights;
+   rec_matrix.float_weights = NULL;
+#endif
+   rec_matrix.weights_idx = gru->idx;
+   compute_generic_gru(&in_matrix, &rec_matrix, state, input);
+}
+#endif
 
 #define MAX_CONV_INPUTS_ALL IMAX(MAX_CONV_INPUTS, DRED_MAX_CONV_INPUTS)
 
+#if 1
 void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input)
 {
    int i;
@@ -302,6 +485,32 @@ void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input)
    compute_activation(output, output, N, layer->activation);
    OPUS_COPY(mem, &tmp[layer->nb_inputs], layer->nb_inputs*(layer->kernel_size-1));
 }
+#else
+void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input)
+{
+   LinearLayer matrix;
+   int N, M;
+   float tmp[MAX_CONV_INPUTS_ALL];
+   celt_assert(input != output);
+   celt_assert(layer->nb_inputs*layer->kernel_size <= MAX_CONV_INPUTS_ALL);
+   OPUS_COPY(tmp, mem, layer->nb_inputs*(layer->kernel_size-1));
+   OPUS_COPY(&tmp[layer->nb_inputs*(layer->kernel_size-1)], input, layer->nb_inputs);
+   M = layer->nb_inputs*layer->kernel_size;
+   N = layer->nb_neurons;
+   matrix.bias = layer->bias;
+   matrix.subias = NULL;
+   matrix.float_weights = layer->input_weights;
+   matrix.weights = NULL;
+   matrix.weights_idx = NULL;
+   matrix.diag = NULL;
+   matrix.nb_inputs = M;
+   matrix.nb_outputs = N;
+   matrix.scale = NULL;
+   compute_linear(&matrix, output, tmp);
+   compute_activation(output, output, N, layer->activation);
+   OPUS_COPY(mem, &tmp[layer->nb_inputs], layer->nb_inputs*(layer->kernel_size-1));
+}
+#endif
 
 void compute_embedding(const EmbeddingLayer *layer, float *output, int input)
 {
diff --git a/dnn/nnet.h b/dnn/nnet.h
--- a/dnn/nnet.h
+++ b/dnn/nnet.h
@@ -61,6 +61,18 @@ typedef struct {
   char name[44];
 } WeightHead;
 
+/* Generic sparse affine transformation. */
+typedef struct {
+  const float *bias;
+  const float *subias;
+  const opus_int8 *weights;
+  const float *float_weights;
+  const int *weights_idx;
+  const float *diag;
+  const float *scale;
+  int nb_inputs;
+  int nb_outputs;
+} LinearLayer;
 
 typedef struct {
   const float *bias;
diff --git a/dnn/vec.h b/dnn/vec.h
--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -56,6 +56,230 @@ typedef signed char qweight;
 typedef float qweight;
 #endif
 
+static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   int i, j;
+   OPUS_CLEAR(out, rows);
+   for (i=0;i<rows;i+=16)
+   {
+      for (j=0;j<cols;j++)
+      {
+         const float * restrict w;
+         float * restrict y;
+         float xj;
+         w = &weights[j*col_stride + i];
+         xj = x[j];
+         y = &out[i];
+         y[0] += w[0]*xj;
+         y[1] += w[1]*xj;
+         y[2] += w[2]*xj;
+         y[3] += w[3]*xj;
+         y[4] += w[4]*xj;
+         y[5] += w[5]*xj;
+         y[6] += w[6]*xj;
+         y[7] += w[7]*xj;
+         y[8] += w[8]*xj;
+         y[9] += w[9]*xj;
+         y[10] += w[10]*xj;
+         y[11] += w[11]*xj;
+         y[12] += w[12]*xj;
+         y[13] += w[13]*xj;
+         y[14] += w[14]*xj;
+         y[15] += w[15]*xj;
+      }
+   }
+}
+
+static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x)
+{
+   int i, j;
+   OPUS_CLEAR(out, rows);
+   for (i=0;i<rows;i+=8)
+   {
+      int cols;
+      cols = *idx++;
+      for (j=0;j<cols;j++)
+      {
+         int pos;
+         float * restrict y;
+         float xj0, xj1, xj2, xj3;
+         pos = (*idx++);
+         xj0 = x[pos+0];
+         xj1 = x[pos+1];
+         xj2 = x[pos+2];
+         xj3 = x[pos+3];
+         y = &out[i];
+         y[0] += w[0]*xj0;
+         y[1] += w[1]*xj0;
+         y[2] += w[2]*xj0;
+         y[3] += w[3]*xj0;
+         y[4] += w[4]*xj0;
+         y[5] += w[5]*xj0;
+         y[6] += w[6]*xj0;
+         y[7] += w[7]*xj0;
+
+         y[0] += w[8]*xj1;
+         y[1] += w[9]*xj1;
+         y[2] += w[10]*xj1;
+         y[3] += w[11]*xj1;
+         y[4] += w[12]*xj1;
+         y[5] += w[13]*xj1;
+         y[6] += w[14]*xj1;
+         y[7] += w[15]*xj1;
+
+         y[0] += w[16]*xj2;
+         y[1] += w[17]*xj2;
+         y[2] += w[18]*xj2;
+         y[3] += w[19]*xj2;
+         y[4] += w[20]*xj2;
+         y[5] += w[21]*xj2;
+         y[6] += w[22]*xj2;
+         y[7] += w[23]*xj2;
+
+         y[0] += w[24]*xj3;
+         y[1] += w[25]*xj3;
+         y[2] += w[26]*xj3;
+         y[3] += w[27]*xj3;
+         y[4] += w[28]*xj3;
+         y[5] += w[29]*xj3;
+         y[6] += w[30]*xj3;
+         y[7] += w[31]*xj3;
+         w += 32;
+      }
+   }
+}
+
+#ifdef USE_SU_BIAS
+static inline void sparse_cgemv8x4(float *out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
+{
+   int i, j;
+   unsigned char x[MAX_INPUTS];
+   for (i=0;i<rows;i++) out[i] = 0;
+   for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);
+   for (i=0;i<rows;i+=8)
+   {
+      int colblocks;
+      colblocks = *idx++;
+      for (j=0;j<colblocks;j++)
+      {
+         int pos;
+         float * restrict y;
+         int xj0, xj1, xj2, xj3;
+         pos = (*idx++);
+         xj0 = x[pos+0];
+         xj1 = x[pos+1];
+         xj2 = x[pos+2];
+         xj3 = x[pos+3];
+         y = &out[i];
+         y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
+         y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
+         y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
+         y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
+         y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
+         y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
+         y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
+         y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
+         w += 32;
+      }
+   }
+   for (i=0;i<rows;i++) out[i] *= scale[i];
+}
+static inline void cgemv8x4(float *out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
+{
+   int i, j;
+   unsigned char x[MAX_INPUTS];
+   for (i=0;i<rows;i++) out[i] = 0;
+   for (i=0;i<cols;i++) x[i] = 127+(int)floor(.5+127*_x[i]);
+   for (i=0;i<rows;i+=8)
+   {
+      for (j=0;j<cols;j+=4)
+      {
+         float *y;
+         float xj0, xj1, xj2, xj3;
+         xj0 = x[j+0];
+         xj1 = x[j+1];
+         xj2 = x[j+2];
+         xj3 = x[j+3];
+         y = &out[i];
+         y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
+         y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
+         y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
+         y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
+         y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
+         y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
+         y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
+         y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
+         w += 32;
+      }
+   }
+   for (i=0;i<rows;i++) out[i] *= scale[i];
+}
+#else
+static inline void sparse_cgemv8x4(float *out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
+{
+   int i, j;
+   opus_int8 x[MAX_INPUTS];
+   for (i=0;i<rows;i++) out[i] = 0;
+   for (i=0;i<cols;i++) x[i] = (int)floor(.5+127*_x[i]);
+   for (i=0;i<rows;i+=8)
+   {
+      int colblocks;
+      colblocks = *idx++;
+      for (j=0;j<colblocks;j++)
+      {
+         int pos;
+         float * restrict y;
+         int xj0, xj1, xj2, xj3;
+         pos = (*idx++);
+         xj0 = x[pos+0];
+         xj1 = x[pos+1];
+         xj2 = x[pos+2];
+         xj3 = x[pos+3];
+         y = &out[i];
+         y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
+         y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
+         y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
+         y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
+         y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
+         y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
+         y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
+         y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
+         w += 32;
+      }
+   }
+   for (i=0;i<rows;i++) out[i] *= scale[i];
+}
+static inline void cgemv8x4(float *out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
+{
+   int i, j;
+   opus_int8 x[MAX_INPUTS];
+   for (i=0;i<rows;i++) out[i] = 0;
+   for (i=0;i<cols;i++) x[i] = (int)floor(.5+127*_x[i]);
+   for (i=0;i<rows;i+=8)
+   {
+      for (j=0;j<cols;j+=4)
+      {
+         float *y;
+         float xj0, xj1, xj2, xj3;
+         xj0 = x[j+0];
+         xj1 = x[j+1];
+         xj2 = x[j+2];
+         xj3 = x[j+3];
+         y = &out[i];
+         y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
+         y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
+         y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
+         y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
+         y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
+         y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
+         y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
+         y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
+         w += 32;
+      }
+   }
+   for (i=0;i<rows;i++) out[i] *= scale[i];
+}
+#endif
 
 /* No AVX2/FMA support */
 #ifndef LPCNET_TEST
diff --git a/dnn/vec_avx.h b/dnn/vec_avx.h
index 733cf6a9..0f494b6d 100644
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -35,6 +35,10 @@
 #include <immintrin.h>
 #include <math.h>
 
+
+#define MAX_INPUTS (2048)
+
+
 /* Use 8-bit dot products unless disabled or if stuck with SSE2. */
 #if (defined(__AVX2__) || defined(__SSSE3__)) && !defined(DISABLE_DOT_PROD)
 #define DOT_PROD
@@ -673,13 +677,209 @@ static inline void sparse_sgemv_accum16(float *out, const float *weights, int ro
    }
 }
 
+static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   int i, j;
+   for (i=0;i<rows;i+=16)
+   {
+      float *y;
+      __m256 vy0, vy8;
+      y = &out[i];
+      vy0 = _mm256_setzero_ps();
+      vy8 = _mm256_setzero_ps();
+      for (j=0;j<cols;j++)
+      {
+         __m256 vxj;
+         __m256 vw;
+         vxj = _mm256_broadcast_ss(&x[j]);
+
+         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
+         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
+
+         vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
+         vy8 = _mm256_fmadd_ps(vw, vxj, vy8);
+      }
+      _mm256_storeu_ps (&y[0], vy0);
+      _mm256_storeu_ps (&y[8], vy8);
+   }
+}
+
+static inline void sparse_sgemv8x4(float *out, const float *weights, const int *idx, int rows, const float *x)
+{
+   int i, j;
+   for (i=0;i<rows;i+=8)
+   {
+      float *y;
+      int cols;
+      __m256 vy0;
+      y = &out[i];
+      vy0 = _mm256_setzero_ps();
+      cols = *idx++;
+      for (j=0;j<cols;j++)
+      {
+         int id;
+         __m256 vxj;
+         __m256 vw;
+         id = *idx++;
+         vxj = _mm256_broadcast_ss(&x[id]);
+         vw = _mm256_loadu_ps(&weights[0]);
+         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
+
+         vxj = _mm256_broadcast_ss(&x[id+1]);
+         vw = _mm256_loadu_ps(&weights[8]);
+         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
+
+         vxj = _mm256_broadcast_ss(&x[id+2]);
+         vw = _mm256_loadu_ps(&weights[16]);
+         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
+
+         vxj = _mm256_broadcast_ss(&x[id+3]);
+         vw = _mm256_loadu_ps(&weights[24]);
+         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
+
+         weights += 32;
+      }
+      _mm256_storeu_ps (&y[0], vy0);
+   }
+}
+
+static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
+{
+   __m256i ones;
+   int i, j;
+   unsigned char x[MAX_INPUTS];
+   ones = _mm256_set1_epi16(1);
+   /*for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);*/
+   vector_ps_to_epi8(x, _x, cols);
+   for (i=0;i<rows;i+=8)
+   {
+      int colblocks;
+      __m256i vy0;
+      __m256 vout;
+      colblocks = *idx++;
+      vy0 = _mm256_setzero_si256();
+      j=0;
+#if 1 /* Unrolling by 4 gives some gain, comment out if it does not. */
+      for (;j<colblocks-3;j+=4)
+      {
+         __m256i tmp;
+         __m256i vxj;
+         __m256i vw;
+         vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+         vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+         vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+         vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+      }
+#endif
+      for (;j<colblocks;j++)
+      {
+         __m256i tmp;
+         __m256i vxj;
+         __m256i vw;
+         int pos;
+         pos = (*idx++);
+         vxj = _mm256_set1_epi32(*(int*)&x[pos]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+      }
+      vout = _mm256_cvtepi32_ps(vy0);
+      vout = _mm256_mul_ps(vout, _mm256_loadu_ps(&scale[i]));
+      _mm256_storeu_ps(&_out[i], vout);
+   }
+}
+static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
+{
+   __m256i ones;
+   int i, j;
+   unsigned char x[MAX_INPUTS];
+   ones = _mm256_set1_epi16(1);
+   /*for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);*/
+   vector_ps_to_epi8(x, _x, cols);
+   for (i=0;i<rows;i+=8)
+   {
+      __m256i vy0;
+      __m256 vout;
+      vy0 = _mm256_setzero_si256();
+      j=0;
+#if 1 /* Unrolling by 4 gives some gain, comment out if it does not. */
+      for (;j<cols-12;j+=16)
+      {
+         __m256i tmp;
+         __m256i vxj;
+         __m256i vw;
+         vxj = _mm256_set1_epi32(*(int*)&x[j]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+         vxj = _mm256_set1_epi32(*(int*)&x[j+4]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+         vxj = _mm256_set1_epi32(*(int*)&x[j+8]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+         vxj = _mm256_set1_epi32(*(int*)&x[j+12]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+      }
+#endif
+      for (;j<cols;j+=4)
+      {
+         __m256i tmp;
+         __m256i vxj;
+         __m256i vw;
+         vxj = _mm256_set1_epi32(*(int*)&x[j]);
+         vw = _mm256_loadu_si256((const __m256i *)w);
+         tmp = _mm256_maddubs_epi16(vxj, vw);
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+      }
+      vout = _mm256_cvtepi32_ps(vy0);
+      vout = _mm256_mul_ps(vout, _mm256_loadu_ps(&scale[i]));
+      _mm256_storeu_ps(&_out[i], vout);
+   }
+}
+
+
 #ifdef DOT_PROD
 #define USE_SU_BIAS
 typedef signed char qweight;
-
-#define MAX_INPUTS (2048)
 #define MAX_OUTPUTS (8192)
diff --git a/dnn/vec_neon.h b/dnn/vec_neon.h
index 61d63d4c..4a56a62d 100644
--- a/dnn/vec_neon.h
+++ b/dnn/vec_neon.h
@@ -43,6 +43,9 @@ typedef signed char qweight;
 typedef float qweight;
 #endif
 
+/* Just so it compiles when those functions aren't needed. */
+static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x) {}
+static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x) {}
 
 #ifndef LPCNET_TEST
 static inline float32x4_t exp4_approx(float32x4_t x)
 {
@@ -295,6 +298,76 @@ static inline int32x4_t vdotprod(int32x4_t acc, int8x16_t a, int8x16_t b)
 }
 #endif
 
+static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
+{
+   int i, j;
+   opus_int8 x[MAX_INPUTS];
+   const float32x4_t const127 = vdupq_n_f32(127.);
+   for (i=0;i<cols;i+=8) {
+      int32x4_t xi0, xi4;
+      int16x8_t x_short;
+      xi0 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i])));
+      xi4 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i+4])));
+      x_short = vcombine_s16(vmovn_s32(xi0), vmovn_s32(xi4));
+      vst1_s8(&x[i], vmovn_s16(x_short));
+   }
+   for (i=0;i<rows;i+=8)
+   {
+      int32x4_t acc0, acc1;
+      acc0 = vdupq_n_s32(0);
+      acc1 = vdupq_n_s32(0);
+      for (j=0;j<cols;j+=4)
+      {
+         int8x16_t vw0, vw1, vx;
+         vx = (int8x16_t)vld1q_dup_s32((int*)&x[j]);
+         vw0 = vld1q_s8(w);
+         vw1 = vld1q_s8(&w[16]);
+         acc0 = vdotprod(acc0, vw0, vx);
+         acc1 = vdotprod(acc1, vw1, vx);
+         w += 32;
+      }
+      vst1q_f32(&_out[i], vmulq_f32(vld1q_f32(&scale[i]), vcvtq_f32_s32(acc0)));
+      vst1q_f32(&_out[i+4], vmulq_f32(vld1q_f32(&scale[i+4]), vcvtq_f32_s32(acc1)));
+   }
+}
+
+static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
+{
+   int i, j;
+   opus_int8 x[MAX_INPUTS];
+   const float32x4_t const127 = vdupq_n_f32(127.);
+   for (i=0;i<cols;i+=8) {
+      int32x4_t xi0, xi4;
+      int16x8_t x_short;
+      xi0 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i])));
+      xi4 = vcvtnq_s32_f32(vmulq_f32(const127, vld1q_f32(&_x[i+4])));
+      x_short = vcombine_s16(vmovn_s32(xi0), vmovn_s32(xi4));
+      vst1_s8(&x[i], vmovn_s16(x_short));
+   }
+   for (i=0;i<rows;i+=8)
+   {
+      int colblocks;
+      int32x4_t acc0, acc1;
+      acc0 = vdupq_n_s32(0);
+      acc1 = vdupq_n_s32(0);
+      colblocks = *idx++;
+      for (j=0;j<colblocks;j++)
+      {
+         int8x16_t vw0, vw1, vx;
+         int pos;
+         pos = (*idx++);
+         vx = (int8x16_t)vld1q_dup_s32((int*)&x[pos]);
+         vw0 = vld1q_s8(w);
+         vw1 = vld1q_s8(&w[16]);
+         acc0 = vdotprod(acc0, vw0, vx);
+         acc1 = vdotprod(acc1, vw1, vx);
+         w += 32;
+      }
+      vst1q_f32(&_out[i], vmulq_f32(vld1q_f32(&scale[i]), vcvtq_f32_s32(acc0)));
+      vst1q_f32(&_out[i+4], vmulq_f32(vld1q_f32(&scale[i+4]), vcvtq_f32_s32(acc1)));
+   }
+}
+
 static inline void sgemv_accum8x4(float *_out, const qweight *w, int rows, int cols, int col_stride, const float *_x)
 {
    int i, j;
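One detail worth noting when reading the 8x4 kernels above: all of them (scalar, AVX2, and NEON) assume the int8 weights are pre-packed into consecutive 8x4 blocks of 32 weights, ordered so that each output row's four input columns are contiguous. This layout is implied by the scalar `cgemv8x4()` in `dnn/vec.h`; the patch itself contains no packing routine, so the helper below is a hypothetical illustration only (it assumes `rows` is a multiple of 8 and `cols` a multiple of 4, as the kernels require):

```c
#include <stdint.h>

/* Hypothetical packing helper, not part of this patch: converts a
   row-major weight matrix W[rows][cols] into the blocked order that
   cgemv8x4() reads, i.e. for each 8-row output block, consecutive
   32-weight groups covering 4 input columns, with
   w[r*4 + c] == W[i + r][j + c]. */
static void pack_weights_8x4(int8_t *dst, const int8_t *src, int rows, int cols)
{
   int i, j, r, c;
   for (i = 0; i < rows; i += 8) {       /* output rows, 8 at a time */
      for (j = 0; j < cols; j += 4) {    /* input columns, 4 at a time */
         for (r = 0; r < 8; r++) {
            for (c = 0; c < 4; c++) {
               *dst++ = src[(i + r) * cols + (j + c)];
            }
         }
      }
   }
}
```

With this ordering, each `w += 32` step in the kernels advances exactly one 8x4 block, so the AVX2 `_mm256_maddubs_epi16` path and the NEON `vdotprod` path consume the same layout as the scalar fallback.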