diff options
author | Jean-Marc Valin <jmvalin@amazon.com> | 2021-07-15 23:06:56 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@amazon.com> | 2021-07-15 23:06:56 +0300 |
commit | c74330e85035c2d47f101fa33797e93ecaaebcd3 (patch) | |
tree | 5030d4200d55aeb04dde3f466ef26f3d4c737d95 /dnn/nnet.c | |
parent | 0d53fad50dfc9f5d023a9d29db596a4f534a23e1 (diff) |
Pre-compute GRU B conditioning
Adapted from PR: https://github.com/mozilla/LPCNet/pull/134
by zhuxiaoxu <zhuxiaoxu@ainirobot.com>
but had to be reworked due to previous weight quantization changes.
Diffstat (limited to 'dnn/nnet.c')
-rw-r--r-- | dnn/nnet.c | 44 |
1 files changed, 44 insertions, 0 deletions
@@ -296,6 +296,50 @@ void compute_gru2(const GRULayer *gru, float *state, const float *input) state[i] = h[i]; } +void compute_gruB(const GRULayer *gru, const float* gru_b_condition, float *state, const float *input) +{ + int i; + int N, M; + int stride; + float zrh[3*MAX_RNN_NEURONS]; + float recur[3*MAX_RNN_NEURONS]; + float *z; + float *r; + float *h; + M = gru->nb_inputs; + N = gru->nb_neurons; + z = zrh; + r = &zrh[N]; + h = &zrh[2*N]; + celt_assert(gru->nb_neurons <= MAX_RNN_NEURONS); + celt_assert(input != state); + celt_assert(gru->reset_after); + stride = 3*N; + /* Compute update gate. */ +#ifdef USE_SU_BIAS + for (i=0;i<3*N;i++) + zrh[i] = gru->subias[i] + gru_b_condition[i]; +#else + for (i=0;i<3*N;i++) + zrh[i] = gru->bias[i] + gru_b_condition[i]; +#endif + sgemv_accum8x4(zrh, gru->input_weights, 3*N, M, stride, input); + for (i=0;i<3*N;i++) + recur[i] = gru->bias[3*N + i]; + sgemv_accum(recur, gru->recurrent_weights, 3*N, N, stride, state); + for (i=0;i<2*N;i++) + zrh[i] += recur[i]; + compute_activation(zrh, zrh, 2*N, ACTIVATION_SIGMOID); + for (i=0;i<N;i++) + h[i] += recur[2*N+i]*r[i]; + compute_activation(h, h, N, gru->activation); + for (i=0;i<N;i++) + h[i] = z[i]*state[i] + (1-z[i])*h[i]; + for (i=0;i<N;i++) + state[i] = h[i]; +} + + void compute_gru3(const GRULayer *gru, float *state, const float *input) { int i; |