diff options
-rw-r--r-- | celt/celt.h | 3 | ||||
-rwxr-xr-x | scripts/dump_rnn.py | 57 | ||||
-rwxr-xr-x | scripts/rnn_train.py | 67 | ||||
-rw-r--r-- | src/analysis.c | 271 | ||||
-rw-r--r-- | src/analysis.h | 15 | ||||
-rw-r--r-- | src/mlp.c | 152 | ||||
-rw-r--r-- | src/mlp.h | 35 | ||||
-rw-r--r-- | src/mlp_data.c | 325 | ||||
-rw-r--r-- | src/mlp_train.c | 501 | ||||
-rw-r--r-- | src/mlp_train.h | 86 | ||||
-rw-r--r-- | src/opus_encoder.c | 11 |
11 files changed, 580 insertions, 943 deletions
diff --git a/celt/celt.h b/celt/celt.h index 70175301..f73f29dd 100644 --- a/celt/celt.h +++ b/celt/celt.h @@ -59,7 +59,8 @@ typedef struct { float noisiness; float activity; float music_prob; - float vad_prob; + float music_prob_min; + float music_prob_max; int bandwidth; float activity_probability; /* Store as Q6 char to save space. */ diff --git a/scripts/dump_rnn.py b/scripts/dump_rnn.py new file mode 100755 index 00000000..dd66403b --- /dev/null +++ b/scripts/dump_rnn.py @@ -0,0 +1,57 @@ +#!/usr/bin/python + +from __future__ import print_function + +from keras.models import Sequential +from keras.layers import Dense +from keras.layers import LSTM +from keras.layers import GRU +from keras.models import load_model +from keras import backend as K + +import numpy as np + +def printVector(f, vector, name): + v = np.reshape(vector, (-1)); + #print('static const float ', name, '[', len(v), '] = \n', file=f) + f.write('static const opus_int16 {}[{}] = {{\n '.format(name, len(v))) + for i in range(0, len(v)): + f.write('{}'.format(int(round(8192*v[i])))) + if (i!=len(v)-1): + f.write(',') + else: + break; + if (i%8==7): + f.write("\n ") + else: + f.write(" ") + #print(v, file=f) + f.write('\n};\n\n') + return; + +def binary_crossentrop2(y_true, y_pred): + return K.mean(2*K.abs(y_true-0.5) * K.binary_crossentropy(y_pred, y_true), axis=-1) + + +model = load_model("weights.hdf5", custom_objects={'binary_crossentrop2': binary_crossentrop2}) + +weights = model.get_weights() + +f = open('rnn_weights.c', 'w') + +f.write('/*This file is automatically generated from a Keras model*/\n\n') +f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "mlp.h"\n\n') + +printVector(f, weights[0], 'layer0_weights') +printVector(f, weights[1], 'layer0_bias') +printVector(f, weights[2], 'layer1_weights') +printVector(f, weights[3], 'layer1_recur_weights') +printVector(f, weights[4], 'layer1_bias') +printVector(f, weights[5], 'layer2_weights') +printVector(f, weights[6], 'layer2_bias') + +f.write('const DenseLayer layer0 = {\n layer0_bias,\n layer0_weights,\n 25, 16, 0\n};\n\n') +f.write('const GRULayer layer1 = {\n layer1_bias,\n layer1_weights,\n layer1_recur_weights,\n 16, 12\n};\n\n') +f.write('const DenseLayer layer2 = {\n layer2_bias,\n layer2_weights,\n 12, 2, 1\n};\n\n') + +f.close() diff --git a/scripts/rnn_train.py b/scripts/rnn_train.py new file mode 100755 index 00000000..ffdaa1e7 --- /dev/null +++ b/scripts/rnn_train.py @@ -0,0 +1,67 @@ +#!/usr/bin/python + +from __future__ import print_function + +from keras.models import Sequential +from keras.models import Model +from keras.layers import Input +from keras.layers import Dense +from keras.layers import LSTM +from keras.layers import GRU +from keras.layers import SimpleRNN +from keras.layers import Dropout +from keras import losses +import h5py + +from keras import backend as K +import numpy as np + +def binary_crossentrop2(y_true, y_pred): + return K.mean(2*K.abs(y_true-0.5) * K.binary_crossentropy(y_pred, y_true), axis=-1) + +print('Build model...') +#model = Sequential() +#model.add(Dense(16, activation='tanh', input_shape=(None, 25))) +#model.add(GRU(12, dropout=0.0, recurrent_dropout=0.0, activation='tanh', recurrent_activation='sigmoid', return_sequences=True)) +#model.add(Dense(2, activation='sigmoid')) + +main_input = Input(shape=(None, 25), name='main_input') +x = Dense(16, activation='tanh')(main_input) +x = GRU(12, dropout=0.1, recurrent_dropout=0.1, activation='tanh', recurrent_activation='sigmoid', return_sequences=True)(x) +x = Dense(2, activation='sigmoid')(x) +model = Model(inputs=main_input, outputs=x) + +batch_size = 64 + +print('Loading data...') +with h5py.File('features.h5', 'r') as hf: + all_data = hf['features'][:] +print('done.') + +window_size = 1500 + +nb_sequences = len(all_data)/window_size +print(nb_sequences, ' sequences') +x_train = all_data[:nb_sequences*window_size, :-2] +x_train = np.reshape(x_train, (nb_sequences, window_size, 25)) + +y_train = np.copy(all_data[:nb_sequences*window_size, -2:]) +y_train = np.reshape(y_train, (nb_sequences, window_size, 2)) + +all_data = 0; +x_train = x_train.astype('float32') +y_train = y_train.astype('float32') + +print(len(x_train), 'train sequences. x shape =', x_train.shape, 'y shape = ', y_train.shape) + +# try using different optimizers and different optimizer configs +model.compile(loss=binary_crossentrop2, + optimizer='adam', + metrics=['binary_accuracy']) + +print('Train...') +model.fit(x_train, y_train, + batch_size=batch_size, + epochs=200, + validation_data=(x_train, y_train)) +model.save("newweights.hdf5") diff --git a/src/analysis.c b/src/analysis.c index f4160e4b..1d6dd829 100644 --- a/src/analysis.c +++ b/src/analysis.c @@ -50,6 +50,8 @@ #ifndef DISABLE_FLOAT_API +#define TRANSITION_PENALTY 10 + static const float dct_table[128] = { 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, 0.250000f, @@ -224,19 +226,22 @@ void tonality_analysis_reset(TonalityAnalysisState *tonal) /* Clear non-reusable fields. */ char *start = (char*)&tonal->TONALITY_ANALYSIS_RESET_START; OPUS_CLEAR(start, sizeof(TonalityAnalysisState) - (start - (char*)tonal)); - tonal->music_confidence = .9f; - tonal->speech_confidence = .1f; } void tonality_get_info(TonalityAnalysisState *tonal, AnalysisInfo *info_out, int len) { int pos; int curr_lookahead; - float psum; float tonality_max; float tonality_avg; int tonality_count; int i; + int pos0; + float prob_avg; + float prob_count; + float prob_min, prob_max; + float vad_prob; + int mpos, vpos; pos = tonal->read_pos; curr_lookahead = tonal->write_pos-tonal->read_pos; @@ -254,6 +259,7 @@ void tonality_get_info(TonalityAnalysisState *tonal, AnalysisInfo *info_out, int pos--; if (pos<0) pos = DETECT_SIZE-1; + pos0 = pos; OPUS_COPY(info_out, &tonal->info[pos], 1); tonality_max = tonality_avg = info_out->tonality; tonality_count = 1; @@ -270,6 +276,107 @@ void tonality_get_info(TonalityAnalysisState *tonal, AnalysisInfo *info_out, int tonality_count++; } info_out->tonality = MAX32(tonality_avg/tonality_count, tonality_max-.2f); + + mpos = vpos = pos0; + /* If we have enough look-ahead, compensate for the ~5-frame delay in the music prob and + ~1 frame delay in the VAD prob. */ + if (curr_lookahead > 15) + { + mpos += 5; + if (mpos>=DETECT_SIZE) + mpos -= DETECT_SIZE; + vpos += 1; + if (vpos>=DETECT_SIZE) + vpos -= DETECT_SIZE; + } + + /* The following calculations attempt to minimize a "badness function" + for the transition. When switching from speech to music, the badness + of switching at frame k is + b_k = S*v_k + \sum_{i=0}^{k-1} v_i*(p_i - T) + where + v_i is the activity probability (VAD) at frame i, + p_i is the music probability at frame i + T is the probability threshold for switching + S is the penalty for switching during active audio rather than silence + the current frame has index i=0 + + Rather than apply badness to directly decide when to switch, what we compute + instead is the threshold for which the optimal switching point is now. When + considering whether to switch now (frame 0) or at frame k, we have: + S*v_0 = S*v_k + \sum_{i=0}^{k-1} v_i*(p_i - T) + which gives us: + T = ( \sum_{i=0}^{k-1} v_i*p_i + S*(v_k-v_0) ) / ( \sum_{i=0}^{k-1} v_i ) + We take the min threshold across all positive values of k (up to the maximum + amount of lookahead we have) to give us the threshold for which the current + frame is the optimal switch point. + + The last step is that we need to consider whether we want to switch at all. + For that we use the average of the music probability over the entire window. + If the threshold is higher than that average we're not going to + switch, so we compute a min with the average as well. The result of all these + min operations is music_prob_min, which gives the threshold for switching to music + if we're currently encoding for speech. + + We do the exact opposite to compute music_prob_max which is used for switching + from music to speech. + */ + prob_min = 1.f; + prob_max = 0.f; + vad_prob = tonal->info[vpos].activity_probability; + prob_count = MAX16(.1f, vad_prob); + prob_avg = MAX16(.1f, vad_prob)*tonal->info[mpos].music_prob; + while (1) + { + float pos_vad; + mpos++; + if (mpos==DETECT_SIZE) + mpos = 0; + if (mpos == tonal->write_pos) + break; + vpos++; + if (vpos==DETECT_SIZE) + vpos = 0; + if (vpos == tonal->write_pos) + break; + pos_vad = tonal->info[vpos].activity_probability; + prob_min = MIN16((prob_avg - TRANSITION_PENALTY*(vad_prob - pos_vad))/prob_count, prob_min); + prob_max = MAX16((prob_avg + TRANSITION_PENALTY*(vad_prob - pos_vad))/prob_count, prob_max); + prob_count += MAX16(.1f, pos_vad); + prob_avg += MAX16(.1f, pos_vad)*tonal->info[mpos].music_prob; + } + info_out->music_prob = prob_avg/prob_count; + prob_min = MIN16(prob_avg/prob_count, prob_min); + prob_max = MAX16(prob_avg/prob_count, prob_max); + prob_min = MAX16(prob_min, 0.f); + prob_max = MIN16(prob_max, 1.f); + + /* If we don't have enough look-ahead, do our best to make a decent decision. */ + if (curr_lookahead < 10) + { + float pmin, pmax; + pmin = prob_min; + pmax = prob_max; + pos = pos0; + /* Look for min/max in the past. */ + for (i=0;i<IMIN(tonal->count-1, 15);i++) + { + pos--; + if (pos < 0) + pos = DETECT_SIZE-1; + pmin = MIN16(pmin, tonal->info[pos].music_prob); + pmax = MAX16(pmax, tonal->info[pos].music_prob); + } + /* Bias against switching on active audio. */ + pmin = MAX16(0.f, pmin - .1f*vad_prob); + pmax = MIN16(1.f, pmax + .1f*vad_prob); + prob_min += (1.f-.1f*curr_lookahead)*(pmin - prob_min); + prob_max += (1.f-.1f*curr_lookahead)*(pmax - prob_max); + } + info_out->music_prob_min = prob_min; + info_out->music_prob_max = prob_max; + + /* printf("%f %f %f %f %f\n", prob_min, prob_max, prob_avg/prob_count, vad_prob, info_out->music_prob); */ tonal->read_subframe += len/(tonal->Fs/400); while (tonal->read_subframe>=8) { @@ -278,21 +385,6 @@ void tonality_get_info(TonalityAnalysisState *tonal, AnalysisInfo *info_out, int } if (tonal->read_pos>=DETECT_SIZE) tonal->read_pos-=DETECT_SIZE; - - /* The -1 is to compensate for the delay in the features themselves. */ - curr_lookahead = IMAX(curr_lookahead-1, 0); - - psum=0; - /* Summing the probability of transition patterns that involve music at - time (DETECT_SIZE-curr_lookahead-1) */ - for (i=0;i<DETECT_SIZE-curr_lookahead;i++) - psum += tonal->pmusic[i]; - for (;i<DETECT_SIZE;i++) - psum += tonal->pspeech[i]; - psum = psum*tonal->music_confidence + (1-psum)*tonal->speech_confidence; - /*printf("%f %f %f %f %f\n", psum, info_out->music_prob, info_out->vad_prob, info_out->activity_probability, info_out->tonality);*/ - - info_out->music_prob = psum; } static const float std_feature_bias[9] = { @@ -352,6 +444,7 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt float band_log2[NB_TBANDS+1]; float leakage_from[NB_TBANDS+1]; float leakage_to[NB_TBANDS+1]; + float layer_out[MAX_NEURONS]; SAVE_STACK; alpha = 1.f/IMIN(10, 1+tonal->count); @@ -368,12 +461,6 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt offset = 3*offset/2; } - if (tonal->count<4) { - if (tonal->application == OPUS_APPLICATION_VOIP) - tonal->music_prob = .1f; - else - tonal->music_prob = .625f; - } kfft = celt_mode->mdct.kfft[0]; if (tonal->count==0) tonal->mem_fill = 240; @@ -761,139 +848,17 @@ static void tonality_analysis(TonalityAnalysisState *tonal, const CELTMode *celt features[23] = info->tonality_slope + 0.069216f; features[24] = tonal->lowECount - 0.067930f; - mlp_process(&net, features, frame_probs); - frame_probs[0] = .5f*(frame_probs[0]+1); - /* Curve fitting between the MLP probability and the actual probability */ - /*frame_probs[0] = .01f + 1.21f*frame_probs[0]*frame_probs[0] - .23f*(float)pow(frame_probs[0], 10);*/ - /* Probability of active audio (as opposed to silence) */ - frame_probs[1] = .5f*frame_probs[1]+.5f; - frame_probs[1] *= frame_probs[1]; + compute_dense(&layer0, layer_out, features); + compute_gru(&layer1, tonal->rnn_state, layer_out); + compute_dense(&layer2, frame_probs, tonal->rnn_state); /* Probability of speech or music vs noise */ info->activity_probability = frame_probs[1]; + /* It seems like the RNN tends to have a bias towards speech and this + warping of the probabilities compensates for it. */ + info->music_prob = frame_probs[0] * (2 - frame_probs[0]); - /*printf("%f %f\n", frame_probs[0], frame_probs[1]);*/ - { - /* Probability of state transition */ - float tau; - /* Represents independence of the MLP probabilities, where - beta=1 means fully independent. */ - float beta; - /* Denormalized probability of speech (p0) and music (p1) after update */ - float p0, p1; - /* Probabilities for "all speech" and "all music" */ - float s0, m0; - /* Probability sum for renormalisation */ - float psum; - /* Instantaneous probability of speech and music, with beta pre-applied. */ - float speech0; - float music0; - float p, q; - - /* More silence transitions for speech than for music. */ - tau = .001f*tonal->music_prob + .01f*(1-tonal->music_prob); - p = MAX16(.05f,MIN16(.95f,frame_probs[1])); - q = MAX16(.05f,MIN16(.95f,tonal->vad_prob)); - beta = .02f+.05f*ABS16(p-q)/(p*(1-q)+q*(1-p)); - /* p0 and p1 are the probabilities of speech and music at this frame - using only information from previous frame and applying the - state transition model */ - p0 = (1-tonal->vad_prob)*(1-tau) + tonal->vad_prob *tau; - p1 = tonal->vad_prob *(1-tau) + (1-tonal->vad_prob)*tau; - /* We apply the current probability with exponent beta to work around - the fact that the probability estimates aren't independent. */ - p0 *= (float)pow(1-frame_probs[1], beta); - p1 *= (float)pow(frame_probs[1], beta); - /* Normalise the probabilities to get the Marokv probability of music. */ - tonal->vad_prob = p1/(p0+p1); - info->vad_prob = tonal->vad_prob; - /* Consider that silence has a 50-50 probability of being speech or music. */ - frame_probs[0] = tonal->vad_prob*frame_probs[0] + (1-tonal->vad_prob)*.5f; - - /* One transition every 3 minutes of active audio */ - tau = .0001f; - /* Adapt beta based on how "unexpected" the new prob is */ - p = MAX16(.05f,MIN16(.95f,frame_probs[0])); - q = MAX16(.05f,MIN16(.95f,tonal->music_prob)); - beta = .02f+.05f*ABS16(p-q)/(p*(1-q)+q*(1-p)); - /* p0 and p1 are the probabilities of speech and music at this frame - using only information from previous frame and applying the - state transition model */ - p0 = (1-tonal->music_prob)*(1-tau) + tonal->music_prob *tau; - p1 = tonal->music_prob *(1-tau) + (1-tonal->music_prob)*tau; - /* We apply the current probability with exponent beta to work around - the fact that the probability estimates aren't independent. */ - p0 *= (float)pow(1-frame_probs[0], beta); - p1 *= (float)pow(frame_probs[0], beta); - /* Normalise the probabilities to get the Marokv probability of music. */ - tonal->music_prob = p1/(p0+p1); - info->music_prob = tonal->music_prob; - - /*printf("%f %f %f %f\n", frame_probs[0], frame_probs[1], tonal->music_prob, tonal->vad_prob);*/ - /* This chunk of code deals with delayed decision. */ - psum=1e-20f; - /* Instantaneous probability of speech and music, with beta pre-applied. */ - speech0 = (float)pow(1-frame_probs[0], beta); - music0 = (float)pow(frame_probs[0], beta); - if (tonal->count==1) - { - if (tonal->application == OPUS_APPLICATION_VOIP) - tonal->pmusic[0] = .1f; - else - tonal->pmusic[0] = .625f; - tonal->pspeech[0] = 1-tonal->pmusic[0]; - } - /* Updated probability of having only speech (s0) or only music (m0), - before considering the new observation. */ - s0 = tonal->pspeech[0] + tonal->pspeech[1]; - m0 = tonal->pmusic [0] + tonal->pmusic [1]; - /* Updates s0 and m0 with instantaneous probability. */ - tonal->pspeech[0] = s0*(1-tau)*speech0; - tonal->pmusic [0] = m0*(1-tau)*music0; - /* Propagate the transition probabilities */ - for (i=1;i<DETECT_SIZE-1;i++) - { - tonal->pspeech[i] = tonal->pspeech[i+1]*speech0; - tonal->pmusic [i] = tonal->pmusic [i+1]*music0; - } - /* Probability that the latest frame is speech, when all the previous ones were music. */ - tonal->pspeech[DETECT_SIZE-1] = m0*tau*speech0; - /* Probability that the latest frame is music, when all the previous ones were speech. */ - tonal->pmusic [DETECT_SIZE-1] = s0*tau*music0; - - /* Renormalise probabilities to 1 */ - for (i=0;i<DETECT_SIZE;i++) - psum += tonal->pspeech[i] + tonal->pmusic[i]; - psum = 1.f/psum; - for (i=0;i<DETECT_SIZE;i++) - { - tonal->pspeech[i] *= psum; - tonal->pmusic [i] *= psum; - } - psum = tonal->pmusic[0]; - for (i=1;i<DETECT_SIZE;i++) - psum += tonal->pspeech[i]; - - /* Estimate our confidence in the speech/music decisions */ - if (frame_probs[1]>.75) - { - if (tonal->music_prob>.9) - { - float adapt; - adapt = 1.f/(++tonal->music_confidence_count); - tonal->music_confidence_count = IMIN(tonal->music_confidence_count, 500); - tonal->music_confidence += adapt*MAX16(-.2f,frame_probs[0]-tonal->music_confidence); - } - if (tonal->music_prob<.1) - { - float adapt; - adapt = 1.f/(++tonal->speech_confidence_count); - tonal->speech_confidence_count = IMIN(tonal->speech_confidence_count, 500); - tonal->speech_confidence += adapt*MIN16(.2f,frame_probs[0]-tonal->speech_confidence); - } - } - } - tonal->last_music = tonal->music_prob>.5f; + /*printf("%f %f %f\n", frame_probs[0], frame_probs[1], info->music_prob);*/ #ifdef MLP_TRAINING for (i=0;i<25;i++) printf("%f ", features[i]); diff --git a/src/analysis.h b/src/analysis.h index cac51dfa..289c845e 100644 --- a/src/analysis.h +++ b/src/analysis.h @@ -30,6 +30,7 @@ #include "celt.h" #include "opus_private.h" +#include "mlp.h" #define NB_FRAMES 8 #define NB_TBANDS 18 @@ -64,28 +65,16 @@ typedef struct { float mem[32]; float cmean[8]; float std[9]; - float music_prob; - float vad_prob; float Etracker; float lowECount; int E_count; - int last_music; int count; int analysis_offset; - /** Probability of having speech for time i to DETECT_SIZE-1 (and music before). - pspeech[0] is the probability that all frames in the window are speech. */ - float pspeech[DETECT_SIZE]; - /** Probability of having music for time i to DETECT_SIZE-1 (and speech before). - pmusic[0] is the probability that all frames in the window are music. */ - float pmusic[DETECT_SIZE]; - float speech_confidence; - float music_confidence; - int speech_confidence_count; - int music_confidence_count; int write_pos; int read_pos; int read_subframe; float hp_ener_accum; + float rnn_state[MAX_NEURONS]; opus_val32 downmix_state[3]; AnalysisInfo info[DETECT_SIZE]; } TonalityAnalysisState; @@ -1,5 +1,5 @@ /* Copyright (c) 2008-2011 Octasic Inc. - Written by Jean-Marc Valin */ + 2012-2017 Jean-Marc Valin */ /* Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions @@ -29,42 +29,13 @@ #include "config.h" #endif +#include <math.h> #include "opus_types.h" #include "opus_defines.h" - -#include <math.h> -#include "mlp.h" #include "arch.h" #include "tansig_table.h" -#define MAX_NEURONS 100 +#include "mlp.h" -#if 0 -static OPUS_INLINE opus_val16 tansig_approx(opus_val32 _x) /* Q19 */ -{ - int i; - opus_val16 xx; /* Q11 */ - /*double x, y;*/ - opus_val16 dy, yy; /* Q14 */ - /*x = 1.9073e-06*_x;*/ - if (_x>=QCONST32(8,19)) - return QCONST32(1.,14); - if (_x<=-QCONST32(8,19)) - return -QCONST32(1.,14); - xx = EXTRACT16(SHR32(_x, 8)); - /*i = lrint(25*x);*/ - i = SHR32(ADD32(1024,MULT16_16(25, xx)),11); - /*x -= .04*i;*/ - xx -= EXTRACT16(SHR32(MULT16_16(20972,i),8)); - /*x = xx*(1./2048);*/ - /*y = tansig_table[250+i];*/ - yy = tansig_table[250+i]; - /*y = yy*(1./16384);*/ - dy = 16384-MULT16_16_Q14(yy,yy); - yy = yy + MULT16_16_Q14(MULT16_16_Q11(xx,dy),(16384 - MULT16_16_Q11(yy,xx))); - return yy; -} -#else -/*extern const float tansig_table[501];*/ static OPUS_INLINE float tansig_approx(float x) { int i; @@ -92,54 +63,79 @@ static OPUS_INLINE float tansig_approx(float x) y = y + x*dy*(1 - y*x); return sign*y; } -#endif -#if 0 -void mlp_process(const MLP *m, const opus_val16 *in, opus_val16 *out) +static OPUS_INLINE float sigmoid_approx(float x) { - int j; - opus_val16 hidden[MAX_NEURONS]; - const opus_val16 *W = m->weights; - /* Copy to tmp_in */ - for (j=0;j<m->topo[1];j++) - { - int k; - opus_val32 sum = SHL32(EXTEND32(*W++),8); - for (k=0;k<m->topo[0];k++) - sum = MAC16_16(sum, in[k],*W++); - hidden[j] = tansig_approx(sum); - } - for (j=0;j<m->topo[2];j++) - { - int k; - opus_val32 sum = SHL32(EXTEND32(*W++),14); - for (k=0;k<m->topo[1];k++) - sum = MAC16_16(sum, hidden[k], *W++); - out[j] = tansig_approx(EXTRACT16(PSHR32(sum,17))); - } + return .5 + .5*tansig_approx(.5*x); } -#else -void mlp_process(const MLP *m, const float *in, float *out) + +void compute_dense(const DenseLayer *layer, float *output, const float *input) { - int j; - float hidden[MAX_NEURONS]; - const float *W = m->weights; - /* Copy to tmp_in */ - for (j=0;j<m->topo[1];j++) - { - int k; - float sum = *W++; - for (k=0;k<m->topo[0];k++) - sum = sum + in[k]**W++; - hidden[j] = tansig_approx(sum); - } - for (j=0;j<m->topo[2];j++) - { - int k; - float sum = *W++; - for (k=0;k<m->topo[1];k++) - sum = sum + hidden[k]**W++; - out[j] = tansig_approx(sum); - } + int i, j; + int N, M; + int stride; + M = layer->nb_inputs; + N = layer->nb_neurons; + stride = N; + for (i=0;i<N;i++) + { + /* Compute update gate. */ + float sum = layer->bias[i]; + for (j=0;j<M;j++) + sum += layer->input_weights[j*stride + i]*input[j]; + output[i] = WEIGHTS_SCALE*sum; + } + if (layer->sigmoid) { + for (i=0;i<N;i++) + output[i] = sigmoid_approx(output[i]); + } else { + for (i=0;i<N;i++) + output[i] = tansig_approx(output[i]); + } } -#endif + +void compute_gru(const GRULayer *gru, float *state, const float *input) +{ + int i, j; + int N, M; + int stride; + float z[MAX_NEURONS]; + float r[MAX_NEURONS]; + float h[MAX_NEURONS]; + M = gru->nb_inputs; + N = gru->nb_neurons; + stride = 3*N; + for (i=0;i<N;i++) + { + /* Compute update gate. */ + float sum = gru->bias[i]; + for (j=0;j<M;j++) + sum += gru->input_weights[j*stride + i]*input[j]; + for (j=0;j<N;j++) + sum += gru->recurrent_weights[j*stride + i]*state[j]; + z[i] = sigmoid_approx(WEIGHTS_SCALE*sum); + } + for (i=0;i<N;i++) + { + /* Compute reset gate. */ + float sum = gru->bias[N + i]; + for (j=0;j<M;j++) + sum += gru->input_weights[N + j*stride + i]*input[j]; + for (j=0;j<N;j++) + sum += gru->recurrent_weights[N + j*stride + i]*state[j]; + r[i] = sigmoid_approx(WEIGHTS_SCALE*sum); + } + for (i=0;i<N;i++) + { + /* Compute output. */ + float sum = gru->bias[2*N + i]; + for (j=0;j<M;j++) + sum += gru->input_weights[2*N + j*stride + i]*input[j]; + for (j=0;j<N;j++) + sum += gru->recurrent_weights[2*N + j*stride + i]*state[j]*r[j]; + h[i] = z[i]*state[i] + (1-z[i])*tansig_approx(WEIGHTS_SCALE*sum); + } + for (i=0;i<N;i++) + state[i] = h[i]; +} + @@ -1,5 +1,4 @@ -/* Copyright (c) 2008-2011 Octasic Inc. - Written by Jean-Marc Valin */ +/* Copyright (c) 2017 Jean-Marc Valin */ /* Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions @@ -28,16 +27,34 @@ #ifndef _MLP_H_ #define _MLP_H_ -#include "arch.h" +#include "opus_types.h" + +#define WEIGHTS_SCALE (1.f/8192) + +#define MAX_NEURONS 20 typedef struct { - int layers; - const int *topo; - const float *weights; -} MLP; + const opus_int16 *bias; + const opus_int16 *input_weights; + int nb_inputs; + int nb_neurons; + int sigmoid; +} DenseLayer; + +typedef struct { + const opus_int16 *bias; + const opus_int16 *input_weights; + const opus_int16 *recurrent_weights; + int nb_inputs; + int nb_neurons; +} GRULayer; + +extern const DenseLayer layer0; +extern const GRULayer layer1; +extern const DenseLayer layer2; -extern const MLP net; +void compute_dense(const DenseLayer *layer, float *output, const float *input); -void mlp_process(const MLP *m, const float *in, float *out); +void compute_gru(const GRULayer *gru, float *state, const float *input); #endif /* _MLP_H_ */ diff --git a/src/mlp_data.c b/src/mlp_data.c index a819880b..10b787d4 100644 --- a/src/mlp_data.c +++ b/src/mlp_data.c @@ -1,112 +1,235 @@ +/*This file is automatically generated from a Keras model*/ + #ifdef HAVE_CONFIG_H #include "config.h" #endif #include "mlp.h" -/* RMS error was 0.280492, seed was 1480478173 */ -/* 0.005976 0.031821 (0.280494 0.280492) done */ +static const opus_int16 layer0_weights[400] = { + -249, 690, -57, 358, -560, -144, 186, 75, + -804, -1176, -433, -78, 125, -1141, -857, -2, + 1892, 91, 976, 1112, -1636, -73, -1740, -1604, + 2012, -1043, 828, 230, 8698, -92, -665, -747, + 1530, -1315, 2317, 697, 2885, -1399, 2661, 483, + -1628, 502, -592, 299, 3910, -781, 2738, 1338, + -1562, -149, 3468, 1448, 3057, 1202, 2098, 2777, + -1540, -3018, -249, 4656, 2508, 373, 2412, -776, + 7160, -519, -917, -155, -1311, -1239, -637, -1245, + -1450, 1963, 3297, 1489, 1582, -123, -549, 1004, + -4085, 8792, -2145, 220, 2741, 624, -3560, 106, + -2476, 661, 1601, 2177, -1793, -623, 3349, 1959, + 2777, -4635, 451, -996, -3260, -665, 1103, 201, + -2566, 3033, 1065, 1866, 989, -102, -1328, 126, + 1, 4365, 82, 2355, -1011, -107, -5323, -1758, + -691, 1744, 683, -2732, 1309, -1135, -726, 1071, + 9423, 1120, -705, -188, -200, -2668, -750, -1839, + 793, 718, -1011, 222, 567, 31, -1520, 3142, + -5491, -3549, -2718, -276, 2078, -706, -779, -2304, + -2983, -660, 1664, -999, -3297, -1200, 1017, -499, + -764, 3215, -720, 255, 1539, -1142, -3604, -351, + -982, 846, 4069, 481, 5673, -1184, -2883, -1387, + 519, -1617, 315, 1875, -119, 2383, 1141, 1583, + 1013, -531, 349, 121, -139, 327, 531, 611, + 853, 1118, 2013, -294, -1150, 693, 531, 583, + -1506, 224, -818, 655, 1981, 1056, -2327, -1457, + -2846, 3779, 1230, -2587, -191, 1647, -3484, -3450, + -3384, -93, -1028, 825, 868, 38, 557, -125, + 1830, 1981, 1063, 9906, -455, 172, -1788, 4417, + 472, -1398, -4638, 999, -6158, 1943, 4703, -2986, + -938, 3053, -631, -384, 848, -3909, 1352, -2362, + -2306, 515, 2385, -2373, -1642, 582, -262, -571, + 8, 1615, -2501, 1225, -660, -857, -522, 2419, + 654, -1137, 67, -890, 83, 23, 2166, 524, + -978, 5330, 1237, 1163, -2251, -142, -2331, 3034, + 395, -1799, 944, 1978, -2788, 1324, 3271, -4643, + -1313, -2472, 1296, -2316, -1803, -10224, -8577, 8271, + -1920, -3366, -1704, 3250, -2514, 11995, 6655, 4298, + 1046, 483, 651, -901, -1417, 804, 396, -2617, + 1000, 2265, 5354, -1050, 2505, 41, 3928, 1878, + -21057, 12783, 32767, -8139, -32768, 1106, -12076, -26511, + -3484, 24604, 8938, 22944, -9490, -6208, -22142, 23250, + -12708, -299, 14432, -2311, -11941, -797, -3287, -4744, + -10758, 10226, -851, 8565, 4104, -4002, 4456, 12642, + 1685, -7093, -997, 16081, 814, -5316, -13491, 12766, + -1637, -213, 7271, -3037, -6772, 3053, -12425, -6955, + 12553, 7635, -32768, -18611, 22929, 3056, 11196, 5202, + 31582, 5741, -22206, 6145, -673, -25488, -7005, -16479, + 10693, -11369, -10848, -1895, 8051, 7360, 1067, -220, + 6643, 17077, -12356, 3288, 4619, 9751, -656, -1217 +}; + +static const opus_int16 layer0_bias[16] = { + -164, 2802, -2100, 410, 4003, -888, 3010, -644, + 4499, -121, 3753, -1606, -4855, -1828, -682, -79 +}; + +static const opus_int16 layer1_weights[576] = { + 543, 2150, 143, 1450, 7898, -3201, -2648, -4311, + 7028, -2608, 1844, 126, -858, 4572, -347, -11298, + 11315, -4344, 1858, -5906, -5962, 2847, -3894, -1496, + 5309, -651, -3143, -3141, 429, -679, -1524, -1966, + -1175, 2917, 97, -1094, -3186, 4346, 832, 3726, + 5452, 1371, 505, -1282, -435, 3438, 691, -2692, + -872, -1332, 3722, 841, -1081, 2414, -1275, 2131, + -7351, -962, -2295, 1141, 2810, -839, 1444, -1005, + 3900, 1160, 1070, -801, -1856, 2152, -79, 122, + -2790, -5641, -2021, -4328, 992, 664, 1078, 4919, + -5314, -665, -4650, -4734, 3417, -300, -3038, 6124, + -1161, -1786, -2922, 10536, 2726, 1200, -1840, 3752, + -3420, 1710, 2414, -2704, 918, 518, 1057, 1837, + 3098, 1665, 2780, 1636, -3883, -150, -3216, -5393, + 1819, -3555, -3063, -3252, -2948, 8249, -3856, -3981, + 406, -5407, -2135, 3006, -1920, -694, 1349, 2321, + -3114, -1262, -1296, -406, -712, 185, 1802, 62, + -1559, -62, 2270, -195, -1043, 2092, -3543, 1833, + 1193, 1880, 3076, 6353, 1671, -634, 3180, -21, + -612, 800, 6405, 2825, 1187, 583, -2961, -6221, + -1035, -1686, 3563, 7102, 7122, 3946, 3264, -2081, + 574, -2400, 22, 112, 1073, -2386, -3224, -3508, + -1347, -3521, 992, -2582, -7175, 1241, -1368, -6035, + -2555, -6012, -11198, -2492, -4061, -7604, -3521, -5613, + -3823, -6300, 6377, -6267, -3568, -1121, -2755, -6177, + 2627, -2735, -4447, -2327, -577, 824, 2159, -1206, + 47, -3988, -3918, -1073, -540, -595, 2777, -1114, + 985, 407, -1907, -3836, -7385, 9579, 120, 4717, + -1921, -5036, 1388, -2388, -1476, 2967, 2905, 3306, + -631, -1730, 4974, 51, -1131, -3307, -1678, -354, + 2481, -1133, 997, -1374, 2350, 1945, -274, -2238, + -1642, 869, 139, -2974, -1210, -362, 3461, -3912, + -7937, -1246, 5396, -6235, -6650, -9613, -5547, 2541, + -330, -2843, -3100, -227, 1859, 3371, 5094, 4045, + -8379, -2052, 363, 2005, 2248, 772, -872, 1686, + -3885, 1413, 704, -379, -1130, -703, -3406, 179, + 2895, 11203, -1085, -2496, -10569, 877, 2982, 4245, + 7216, -3703, 2468, 1361, -66, 236, -958, -3101, + 2424, -2604, 1854, -5674, 2951, -1898, 3078, 20, + 1217, -3799, 802, -458, -1522, -3094, -2448, -2067, + 658, -3163, 1976, -1577, -8063, 380, -1328, 5963, + -7396, -5218, -7379, -9166, -616, -1731, 2383, 3735, + 10889, -5348, 1128, -6396, -4613, -1547, 2619, -2967, + 2229, 3582, -156, -3970, -2606, -3270, 2515, -568, + -2800, -3145, -2641, 2530, 1079, 3184, -814, -1762, + 2128, -6864, 5163, -3934, 2410, 2574, 1568, -5281, + -1199, -2462, 713, -1456, 4651, -8439, -2239, -4620, + 316, 1772, 89, -2021, -658, -9442, -1249, -195, + -1311, -1129, 1734, 1991, 421, 579, 833, 2917, + 1025, -3243, -2909, 1950, -2845, 898, -1011, 5505, + 4705, 2989, -4835, -939, 3768, -1641, 10910, 34, + -938, 1839, 4835, -2526, -1699, -9939, 4135, 2330, + 746, -2420, 898, 588, -3496, -2904, -3896, 639, + 1046, 440, 1254, 2025, 2089, 3468, 697, 888, + 4553, 2152, 4522, 2916, 3432, 4376, -717, -8019, + 8063, -1602, -5389, -1549, 4541, 412, 413, -5267, + 5859, 147, 2962, 6490, -2794, 1448, -1348, -815, + -1089, -934, 1485, -1420, 827, -2345, -403, 2359, + -1298, 238, 1127, 1984, 3667, -6776, 1191, -1049, + 6323, 3381, 4703, 5709, 1693, -3948, -4716, 5403, + -3221, -1108, 478, -4250, 2643, 1458, -4684, -5321, + -1610, -1048, 4730, 1253, 1975, 1904, 2112, -1591, + -5355, 1317, -2438, 113, -1285, 4023, -1129, 3054, + -5091, 1484, -742, -1258, 1044, -1035, -442, 789, + 1525, 10987, -897, 2773, 357, 4770, 1942, 524, + 1315, 3575, -656, 1394, -14, -4854, 2764, 5455, + 1649, 1005, -1792, 1558, -1490, 3447, -1066, 662, + -974, -870, 1611, 2541, -2744, -1782, -1456, -820, + 261, -1722, -3869, -9244, 4372, 4013, -2733, -13592, + 5458, -6824, -634, 707, 742, 4432, -3446, -4348, + 916, 505, 3267, -9216, -3492, 2121, -4923, 4175, + -119, -1497, 1421, 3593, 1398, 273, 2351, 404 +}; + +static const opus_int16 layer1_recur_weights[432] = { + 381, -8053, -3581, -73, 5728, -10914, -4592, -14935, + 2526, -3600, 3424, 5804, -2523, 2785, -2245, 734, + 1045, -2857, 3888, -11398, 3406, -2679, 4999, -103, + 6707, -7102, 1158, -4524, 3212, 2065, -255, -4255, + 1682, -987, 333, 1958, 2943, -1600, 6811, 2103, + 4030, -4778, 5490, -11909, -1505, 3493, -9066, -3412, + -1673, -7387, -1995, 451, -2989, -2608, 317, 2076, + -6350, 4404, -1222, -3854, -4675, 12616, 3739, 126, + 1343, 8117, 620, -415, -1140, -931, -2678, -1561, + -1454, 1010, 1821, -1230, -3869, 3745, 2041, -1243, + -196, -4974, -9547, -6367, 3797, 105, -698, -1409, + -7030, 5843, -6749, -7885, -1051, 3730, -1202, 2938, + 1536, 2797, 4495, -309, 1954, 1637, 3972, 723, + 1782, 4101, 5525, -6803, 3625, 4203, -3680, -4308, + -5662, 2223, 1929, 1113, 7828, 61, -5548, -10833, + 8655, 3489, 3680, -829, -496, 6740, 1317, -1402, + 2411, 402, 1420, 1971, -3876, 4533, 4610, 6555, + 2928, -2090, -1689, 1243, 3253, 1051, 4787, -3870, + -2253, 4030, -507, 3956, -7122, 6049, 3373, 5868, + 782, 3961, -2132, -3936, 3944, -195, 1283, -382, + -141, 1447, 2272, 4714, 579, 3492, -2719, 937, + 3498, -5240, 3375, 3040, 290, -7514, -2126, -7146, + 3084, 1281, 4354, 338, 5197, -1488, 1623, 1854, + -2707, -2176, 3413, -2245, 851, 1715, -2870, 1309, + -1127, 662, -1673, 7551, -4901, -4459, 1943, -5998, + -4459, 1988, -1437, -6808, -530, 812, 6763, 1088, + -108, -547, -2758, 5672, 857, 2366, 1770, -3537, + -8239, 63, 6457, 3256, 2453, 5478, 3192, 4728, + -5188, -1048, -1468, 1944, -1620, -4830, 8233, 4379, + 887, -1339, 1825, 8806, -7448, 5491, 2284, 1983, + 4417, -50, -411, -1528, -609, 3553, -7104, 2208, + -4777, -877, -3517, 939, -5368, -7444, 4267, -994, + -3320, 3897, 1161, 3366, -6309, 6119, -3928, -2835, + 1384, -1238, 1558, -90, -1277, 3429, -2350, 929, + -7380, 705, -1443, -6141, -4110, 5939, 3391, -2137, + 222, 408, 619, 5516, 6060, 471, -2335, 31, + 636, -7196, 2346, -2082, 2530, -2093, 1603, -7208, + -6764, 2089, -10548, -3235, -3035, -9519, 5596, -5862, + -264, -514, -5881, 2064, 2158, -688, 1983, 9081, + -395, 1106, 1501, 506, -466, -3651, -879, 9723, + 5714, -1403, 3090, 2208, -127, -6849, -579, -1405, + 6088, -8262, -8095, -1043, -9232, -1771, -2790, -5700, + -1568, -1509, -1257, -2664, -1594, 560, -7664, -3712, + -971, 3808, -3434, -1332, -3769, -1509, 316, 3281, + 1581, -2888, -2234, -118, 919, 3520, 8085, -2894, + 1110, 12122, -1275, -2171, -1876, 8625, 1850, 1449, + 6177, 1800, 627, -5902, 3864, 4634, -3149, -1776, + 1389, 2766, 481, 2372, -71, 1265, -357, 1275, + -2011, 2432, 8081, 2382, 8879, 1983, -1742, -4043, + -361, 6496, 5009, -320, 4582, -2144, -4184, -1141, + -2661, -3733, -380, -1826, -17320, -3020, -11362, -10212, + -2959, -897, -2687, 1760, 2843, 836, -1765, 2219, + -3431, 298, 1666, -4254, 1589, -244, -745, -1628, + 1684, 2892, -4366, 2072, -6710, -1399, -8910, 2407 +}; + +static const opus_int16 layer1_bias[36] = { + 14206, 6258, 9052, 6611, -3603, 8785, 5625, 9775, + 6516, 4736, 8943, 3466, -888, -778, 5042, -3041, + 2719, 1724, 1216, 1698, 805, 2729, 1820, 4066, + -3456, 3091, 1570, 542, 599, 2583, 2052, 1258, + -2255, 1508, 1183, -5095 +}; + +static const opus_int16 layer2_weights[24] = { + 946, -14834, -5002, 14299, 10342, 1471, 7109, -508, + 11745, -1786, -621, 15227, -4577, 30114, 5174, 12698, + 22279, -527, 7727, 2246, 9892, -2297, -15579, 853 +}; -static const float weights[450] = { +static const opus_int16 layer2_bias[2] = { + 3700, 8418 +}; -/* hidden layer */ --0.514624f, 0.0234227f, -0.14329f, -0.0878216f, -0.00187827f, --0.0257443f, 0.108524f, 0.00333881f, 0.00585017f, -0.0246132f, -0.142723f, -0.00436494f, 0.0101354f, -0.11124f, -0.0809367f, --0.0750772f, 0.0295524f, 0.00823944f, 0.150392f, 0.0320876f, --0.0710564f, -1.43818f, 0.652076f, 0.0650744f, -1.54821f, -0.168949f, -1.92724f, 0.0517976f, -0.0670737f, -0.0690121f, -0.00247528f, -0.0522024f, 0.0631368f, 0.0532776f, 0.047751f, --0.011715f, 0.142374f, -0.0290885f, -0.279263f, -0.433499f, --0.0795174f, -0.380458f, -0.051263f, 0.218537f, -0.322478f, -1.06667f, -0.104607f, -4.70108f, 0.312037f, 0.277397f, --2.71859f, 1.70037f, -0.141845f, 0.0115618f, 0.0629883f, -0.0403871f, 0.0139428f, -0.00430733f, -0.0429038f, -0.0590318f, --0.0501526f, -0.0284802f, -0.0415686f, -0.0438999f, 0.0822666f, -0.197194f, 0.0363275f, -0.0584307f, 0.0752364f, -0.0799796f, --0.146275f, 0.161661f, -0.184585f, 0.145568f, 0.442823f, -1.61221f, 1.11162f, 2.62177f, -2.482f, -0.112599f, --0.110366f, -0.140794f, -0.181694f, 0.0648674f, 0.0842248f, -0.0933993f, 0.150122f, 0.129171f, 0.176848f, 0.141758f, --0.271822f, 0.235113f, 0.0668579f, -0.433957f, 0.113633f, --0.169348f, -1.40091f, 0.62861f, -0.134236f, 0.402173f, -1.86373f, 1.53998f, -4.32084f, 0.735343f, 0.800214f, --0.00968415f, 0.0425904f, 0.0196811f, -0.018426f, -0.000343953f, --0.00416389f, 0.00111558f, 0.0173069f, -0.00998596f, -0.025898f, -0.00123764f, -0.00520373f, -0.0565033f, 0.0637394f, 0.0051213f, -0.0221361f, 0.00819962f, -0.0467061f, -0.0548258f, -0.00314063f, --1.18332f, 1.88091f, -0.41148f, -2.95727f, -0.521449f, --0.271641f, 0.124946f, -0.0532936f, 0.101515f, 0.000208564f, --0.0488748f, 0.0642388f, -0.0383848f, 0.0135046f, -0.0413592f, --0.0326402f, -0.0137421f, -0.0225219f, -0.0917294f, -0.277759f, --0.185418f, 0.0471128f, -0.125879f, 0.262467f, -0.212794f, --0.112931f, -1.99885f, -0.404787f, 0.224402f, 0.637962f, --0.27808f, -0.0723953f, -0.0537655f, -0.0336359f, -0.0906601f, --0.0641309f, -0.0713542f, 0.0524317f, 0.00608819f, 0.0754101f, --0.0488401f, -0.00671865f, 0.0418239f, 0.0536284f, -0.132639f, -0.0267648f, -0.248432f, -0.0104153f, 0.035544f, -0.212753f, --0.302895f, -0.0357854f, 0.376838f, 0.597025f, -0.664647f, -0.268422f, -0.376772f, -1.05472f, 0.0144178f, 0.179122f, -0.0360155f, 0.220262f, -0.0056381f, 0.0317197f, 0.0621066f, --0.00779298f, 0.00789378f, 0.00350605f, 0.0104809f, 0.0362871f, --0.157708f, -0.0659779f, -0.0926278f, 0.00770791f, 0.0631621f, -0.0817343f, -0.424295f, -0.0437727f, -0.24251f, 0.711217f, --0.736455f, -2.194f, -0.107612f, -0.175156f, -0.0366573f, --0.0123156f, -0.0628516f, -0.0218977f, -0.00693699f, 0.00695185f, -0.00507362f, 0.00359334f, 0.0052661f, 0.035561f, 0.0382701f, -0.0342179f, -0.00790271f, -0.0170925f, 0.047029f, 0.0197362f, --0.0153435f, 0.0644152f, -0.36862f, -0.0674876f, -2.82672f, -1.34122f, -0.0788029f, -3.47792f, 0.507246f, -0.816378f, --0.0142383f, -0.127349f, -0.106926f, -0.0359524f, 0.105045f, -0.291554f, 0.195413f, 0.0866214f, -0.066577f, -0.102188f, -0.0979466f, -0.12982f, 0.400181f, -0.409336f, -0.0593326f, --0.0656203f, -0.204474f, 0.179802f, 0.000509084f, 0.0995954f, --2.377f, -0.686359f, 0.934861f, 1.10261f, 1.3901f, --4.33616f, -0.00264017f, 0.00713045f, 0.106264f, 0.143726f, --0.0685305f, -0.054656f, -0.0176725f, -0.0772669f, -0.0264526f, --0.0103824f, -0.0269872f, -0.00687f, 0.225804f, 0.407751f, --0.0612611f, -0.0576863f, -0.180131f, -0.222772f, -0.461742f, -0.335236f, 1.03399f, 4.24112f, -0.345796f, -0.594549f, --76.1407f, -0.265276f, 0.0507719f, 0.0643044f, 0.0384832f, -0.0424459f, -0.0387817f, -0.0235996f, -0.0740556f, -0.0270029f, -0.00882177f, -0.0552371f, -0.00485851f, 0.314295f, 0.360431f, --0.0787085f, 0.110355f, -0.415958f, -0.385088f, -0.272224f, --1.55108f, -0.141848f, 0.448877f, -0.563447f, -2.31403f, --0.120077f, -1.49918f, -0.817726f, -0.0495854f, -0.0230782f, --0.0224014f, 0.117076f, 0.0393216f, 0.051997f, 0.0330763f, --0.110796f, 0.0211117f, -0.0197258f, 0.0187461f, 0.0125183f, -0.14876f, 0.0920565f, -0.342475f, 0.135272f, -0.168155f, --0.033423f, -0.0604611f, -0.128835f, 0.664947f, -0.144997f, -2.27649f, 1.28663f, 0.841217f, -2.42807f, 0.0230471f, -0.226709f, -0.0374803f, 0.155436f, 0.0400342f, -0.184686f, -0.128488f, -0.0939518f, -0.0578559f, 0.0265967f, -0.0999322f, --0.0322768f, -0.322994f, -0.189371f, -0.738069f, -0.0754914f, -0.214717f, -0.093728f, -0.695741f, 0.0899298f, -2.06188f, --0.273719f, -0.896977f, 0.130553f, 0.134638f, 1.29355f, -0.00520749f, -0.0324224f, 0.00530451f, 0.0192385f, 0.00328708f, -0.0250838f, 0.0053365f, -0.0177321f, 0.00618789f, 0.00525364f, -0.00104596f, -0.0360459f, 0.0402403f, -0.0406351f, 0.0136883f, -0.0880722f, -0.0197449f, 0.089938f, 0.0100456f, -0.0475638f, --0.73267f, 0.037433f, -0.146551f, -0.230221f, -3.06489f, --1.40194f, 0.0198483f, 0.0397953f, -0.0190239f, 0.0470715f, --0.131363f, -0.191721f, -0.0176224f, -0.0480352f, -0.221799f, --0.26794f, -0.0292615f, 0.0612127f, -0.129877f, 0.00628332f, --0.085918f, 0.0175379f, 0.0541011f, -0.0810874f, -0.380809f, --0.222056f, -0.508859f, -0.473369f, 0.484958f, -2.28411f, -0.0139516f, -/* output layer */ -3.90017f, 1.71789f, -1.43372f, -2.70839f, 1.77107f, -5.48006f, 1.44661f, 2.01134f, -1.88383f, -3.64958f, --1.26351f, 0.779421f, 2.11357f, 3.10409f, 1.68846f, --4.46197f, -1.61455f, 3.59832f, 2.43531f, -1.26458f, -0.417941f, 1.47437f, 2.16635f, -1.909f, -0.828869f, -1.38805f, -2.67975f, -0.110044f, 1.95596f, 0.697931f, --0.313226f, -0.889315f, 0.283236f, 0.946102f, }; +const DenseLayer layer0 = { + layer0_bias, + layer0_weights, + 25, 16, 0 +}; -static const int topo[3] = {25, 16, 2}; +const GRULayer layer1 = { + layer1_bias, + layer1_weights, + layer1_recur_weights, + 16, 12 +}; -const MLP net = { - 3, - topo, - weights +const DenseLayer layer2 = { + layer2_bias, + layer2_weights, + 12, 2, 1 }; + diff --git a/src/mlp_train.c b/src/mlp_train.c deleted file mode 100644 index 8d9d127a..00000000 --- a/src/mlp_train.c +++ /dev/null @@ -1,501 +0,0 @@ -/* Copyright (c) 2008-2011 Octasic Inc. - Written by Jean-Marc Valin */ -/* - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions - are met: - - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR - CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -*/ - - -#include "mlp_train.h" -#include <stdlib.h> -#include <stdio.h> -#include <string.h> -#include <semaphore.h> -#include <pthread.h> -#include <time.h> -#include <signal.h> - -int stopped = 0; - -void handler(int sig) -{ - stopped = 1; - signal(sig, handler); -} - -MLPTrain * mlp_init(int *topo, int nbLayers, float *inputs, float *outputs, int nbSamples) -{ - int i, j, k; - MLPTrain *net; - int inDim, outDim; - net = malloc(sizeof(*net)); - net->topo = malloc(nbLayers*sizeof(net->topo[0])); - for (i=0;i<nbLayers;i++) - net->topo[i] = topo[i]; - inDim = topo[0]; - outDim = topo[nbLayers-1]; - net->in_rate = malloc((inDim+1)*sizeof(net->in_rate[0])); - net->weights = malloc((nbLayers-1)*sizeof(net->weights)); - net->best_weights = malloc((nbLayers-1)*sizeof(net->weights)); - for (i=0;i<nbLayers-1;i++) - { - net->weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0])); - net->best_weights[i] = malloc((topo[i]+1)*topo[i+1]*sizeof(net->weights[0][0])); - } - double inMean[inDim]; - for (j=0;j<inDim;j++) - { - double std=0; - inMean[j] = 0; - for (i=0;i<nbSamples;i++) - { - inMean[j] += inputs[i*inDim+j]; - std += inputs[i*inDim+j]*inputs[i*inDim+j]; - } - inMean[j] /= nbSamples; - std /= nbSamples; - net->in_rate[1+j] = .5/(.0001+std); - std = std-inMean[j]*inMean[j]; - if (std<.001) - std = .001; - std = 1/sqrt(inDim*std); - for (k=0;k<topo[1];k++) - net->weights[0][k*(topo[0]+1)+j+1] = randn(std); - } - net->in_rate[0] = 1; - for (j=0;j<topo[1];j++) - { - double sum = 0; - for (k=0;k<inDim;k++) - sum += inMean[k]*net->weights[0][j*(topo[0]+1)+k+1]; - net->weights[0][j*(topo[0]+1)] = -sum; - } - for (j=0;j<outDim;j++) - { - double mean = 0; - double std; - for (i=0;i<nbSamples;i++) - mean += outputs[i*outDim+j]; - mean /= nbSamples; - std = 1/sqrt(topo[nbLayers-2]); - net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)] = mean; - for (k=0;k<topo[nbLayers-2];k++) - net->weights[nbLayers-2][j*(topo[nbLayers-2]+1)+k+1] = randn(std); - } - return net; -} - -#define MAX_NEURONS 100 -#define MAX_OUT 10 - -double compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamples, double *W0_grad, double *W1_grad, double *error_rate) -{ - int i,j; - int s; - int inDim, outDim, hiddenDim; - int *topo; - double *W0, *W1; - double rms=0; - int W0_size, W1_size; - double hidden[MAX_NEURONS]; - double netOut[MAX_NEURONS]; - double error[MAX_NEURONS]; - - topo = net->topo; - inDim = net->topo[0]; - hiddenDim = net->topo[1]; - outDim = net->topo[2]; - W0_size = (topo[0]+1)*topo[1]; - W1_size = (topo[1]+1)*topo[2]; - W0 = net->weights[0]; - W1 = net->weights[1]; - memset(W0_grad, 0, W0_size*sizeof(double)); - memset(W1_grad, 0, W1_size*sizeof(double)); - for (i=0;i<outDim;i++) - netOut[i] = outputs[i]; - for (i=0;i<outDim;i++) - error_rate[i] = 0; - for (s=0;s<nbSamples;s++) - { - float *in, *out; - float inp[inDim]; - in = inputs+s*inDim; - out = outputs + s*outDim; - for (j=0;j<inDim;j++) - inp[j] = in[j]; - for (i=0;i<hiddenDim;i++) - { - double sum = W0[i*(inDim+1)]; - for (j=0;j<inDim;j++) - sum += W0[i*(inDim+1)+j+1]*inp[j]; - hidden[i] = tansig_approx(sum); - } - for (i=0;i<outDim;i++) - { - double sum = W1[i*(hiddenDim+1)]; - for (j=0;j<hiddenDim;j++) - sum += W1[i*(hiddenDim+1)+j+1]*hidden[j]; - netOut[i] = tansig_approx(sum); - error[i] = out[i] - netOut[i]; - if (out[i] == 0) error[i] *= .0; - error_rate[i] += fabs(error[i])>1; - if (i==0) error[i] *= 5; - rms += error[i]*error[i]; - /*error[i] = error[i]/(1+fabs(error[i]));*/ - } - /* Back-propagate error */ - for (i=0;i<outDim;i++) - { - double grad = 1-netOut[i]*netOut[i]; - W1_grad[i*(hiddenDim+1)] += error[i]*grad; - for (j=0;j<hiddenDim;j++) - W1_grad[i*(hiddenDim+1)+j+1] += grad*error[i]*hidden[j]; - } - for (i=0;i<hiddenDim;i++) - { - double grad; - grad = 0; - for (j=0;j<outDim;j++) - grad += error[j]*W1[j*(hiddenDim+1)+i+1]; - grad *= 1-hidden[i]*hidden[i]; - W0_grad[i*(inDim+1)] += grad; - for (j=0;j<inDim;j++) - W0_grad[i*(inDim+1)+j+1] += grad*inp[j]; - } - } - return rms; -} - -#define NB_THREADS 8 - -sem_t sem_begin[NB_THREADS]; -sem_t sem_end[NB_THREADS]; - -struct GradientArg { - int id; - int done; - MLPTrain *net; - float *inputs; - float *outputs; - int nbSamples; - double *W0_grad; - double *W1_grad; - double rms; - double error_rate[MAX_OUT]; -}; - -void *gradient_thread_process(void *_arg) -{ - int W0_size, W1_size; - struct GradientArg *arg = _arg; - int *topo = arg->net->topo; - W0_size = (topo[0]+1)*topo[1]; - W1_size = (topo[1]+1)*topo[2]; - double W0_grad[W0_size]; - double W1_grad[W1_size]; - arg->W0_grad = W0_grad; - arg->W1_grad = W1_grad; - while (1) - { - sem_wait(&sem_begin[arg->id]); - if (arg->done) - break; - arg->rms = compute_gradient(arg->net, arg->inputs, arg->outputs, arg->nbSamples, arg->W0_grad, arg->W1_grad, arg->error_rate); - sem_post(&sem_end[arg->id]); - } - fprintf(stderr, "done\n"); - return NULL; -} - -float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSamples, int nbEpoch, float rate) -{ - int i, j; - int e; - float best_rms = 1e10; - int inDim, outDim, hiddenDim; - int *topo; - double *W0, *W1, *best_W0, *best_W1; - double *W0_grad, *W1_grad; - double *W0_oldgrad, *W1_oldgrad; - double *W0_rate, *W1_rate; - double *best_W0_rate, *best_W1_rate; - int W0_size, W1_size; - topo = net->topo; - W0_size = (topo[0]+1)*topo[1]; - W1_size = (topo[1]+1)*topo[2]; - struct GradientArg args[NB_THREADS]; - pthread_t thread[NB_THREADS]; - int samplePerPart = nbSamples/NB_THREADS; - int count_worse=0; - int count_retries=0; - - topo = net->topo; - inDim = net->topo[0]; - hiddenDim = net->topo[1]; - outDim = net->topo[2]; - W0 = net->weights[0]; - W1 = net->weights[1]; - best_W0 = net->best_weights[0]; - best_W1 = net->best_weights[1]; - W0_grad = malloc(W0_size*sizeof(double)); - W1_grad = malloc(W1_size*sizeof(double)); - W0_oldgrad = malloc(W0_size*sizeof(double)); - W1_oldgrad = malloc(W1_size*sizeof(double)); - W0_rate = malloc(W0_size*sizeof(double)); - W1_rate = malloc(W1_size*sizeof(double)); - best_W0_rate = malloc(W0_size*sizeof(double)); - best_W1_rate = malloc(W1_size*sizeof(double)); - memset(W0_grad, 0, W0_size*sizeof(double)); - memset(W0_oldgrad, 0, W0_size*sizeof(double)); - memset(W1_grad, 0, W1_size*sizeof(double)); - memset(W1_oldgrad, 0, W1_size*sizeof(double)); - - rate /= nbSamples; - for (i=0;i<hiddenDim;i++) - for (j=0;j<inDim+1;j++) - W0_rate[i*(inDim+1)+j] = rate*net->in_rate[j]; - for (i=0;i<W1_size;i++) - W1_rate[i] = rate; - - for (i=0;i<NB_THREADS;i++) - { - args[i].net = net; - args[i].inputs = inputs+i*samplePerPart*inDim; - args[i].outputs = outputs+i*samplePerPart*outDim; - args[i].nbSamples = samplePerPart; - args[i].id = i; - args[i].done = 0; - sem_init(&sem_begin[i], 0, 0); - sem_init(&sem_end[i], 0, 0); - pthread_create(&thread[i], NULL, gradient_thread_process, &args[i]); - } - for (e=0;e<nbEpoch;e++) - { - double rms=0; - double error_rate[2] = {0,0}; - for (i=0;i<NB_THREADS;i++) - { - sem_post(&sem_begin[i]); - } - memset(W0_grad, 0, W0_size*sizeof(double)); - memset(W1_grad, 0, W1_size*sizeof(double)); - for (i=0;i<NB_THREADS;i++) - { - sem_wait(&sem_end[i]); - rms += args[i].rms; - error_rate[0] += args[i].error_rate[0]; - error_rate[1] += args[i].error_rate[1]; - for (j=0;j<W0_size;j++) - W0_grad[j] += args[i].W0_grad[j]; - for (j=0;j<W1_size;j++) - W1_grad[j] += args[i].W1_grad[j]; - } - - float mean_rate = 0, min_rate = 1e10; - rms = (rms/(outDim*nbSamples)); - error_rate[0] = (error_rate[0]/(nbSamples)); - error_rate[1] = (error_rate[1]/(nbSamples)); - fprintf (stderr, "%f %f (%f %f) ", error_rate[0], error_rate[1], rms, best_rms); - if (rms < best_rms) - { - best_rms = rms; - for (i=0;i<W0_size;i++) - { - best_W0[i] = W0[i]; - best_W0_rate[i] = W0_rate[i]; - } - for (i=0;i<W1_size;i++) - { - best_W1[i] = W1[i]; - best_W1_rate[i] = W1_rate[i]; - } - count_worse=0; - count_retries=0; - } else { - count_worse++; - if (count_worse>30) - { - count_retries++; - count_worse=0; - for (i=0;i<W0_size;i++) - { - W0[i] = best_W0[i]; - best_W0_rate[i] *= .7; - if (best_W0_rate[i]<1e-15) best_W0_rate[i]=1e-15; - W0_rate[i] = best_W0_rate[i]; - W0_grad[i] = 0; - } - for (i=0;i<W1_size;i++) - { - W1[i] = best_W1[i]; - best_W1_rate[i] *= .8; - if (best_W1_rate[i]<1e-15) best_W1_rate[i]=1e-15; - W1_rate[i] = best_W1_rate[i]; - W1_grad[i] = 0; - } - } - } - if (count_retries>10) - break; - for (i=0;i<W0_size;i++) - { - if (W0_oldgrad[i]*W0_grad[i] > 0) - W0_rate[i] *= 1.01; - else if (W0_oldgrad[i]*W0_grad[i] < 0) - W0_rate[i] *= .9; - mean_rate += W0_rate[i]; - if (W0_rate[i] < min_rate) - min_rate = W0_rate[i]; - if (W0_rate[i] < 1e-15) - W0_rate[i] = 1e-15; - /*if (W0_rate[i] > .01) - W0_rate[i] = .01;*/ - W0_oldgrad[i] = W0_grad[i]; - W0[i] += W0_grad[i]*W0_rate[i]; - } - for (i=0;i<W1_size;i++) - { - if (W1_oldgrad[i]*W1_grad[i] > 0) - W1_rate[i] *= 1.01; - else if (W1_oldgrad[i]*W1_grad[i] < 0) - W1_rate[i] *= .9; - mean_rate += W1_rate[i]; - if (W1_rate[i] < min_rate) - min_rate = W1_rate[i]; - if (W1_rate[i] < 1e-15) - W1_rate[i] = 1e-15; - W1_oldgrad[i] = W1_grad[i]; - W1[i] += W1_grad[i]*W1_rate[i]; - } - mean_rate /= (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2]; - fprintf (stderr, "%g %d", mean_rate, e); - if (count_retries) - fprintf(stderr, " %d", count_retries); - fprintf(stderr, "\n"); - if (stopped) - break; - } - for (i=0;i<NB_THREADS;i++) - { - args[i].done = 1; - sem_post(&sem_begin[i]); - pthread_join(thread[i], NULL); - fprintf (stderr, "joined %d\n", i); - } - free(W0_grad); - free(W0_oldgrad); - free(W1_grad); - free(W1_oldgrad); - free(W0_rate); - free(best_W0_rate); - free(W1_rate); - free(best_W1_rate); - return best_rms; -} - -int main(int argc, char **argv) -{ - int i, j; - int nbInputs; - int nbOutputs; - int nbHidden; - int nbSamples; - int nbEpoch; - int nbRealInputs; - unsigned int seed; - int ret; - float rms; - float *inputs; - float *outputs; - if (argc!=6) - { - fprintf (stderr, "usage: mlp_train <inputs> <hidden> <outputs> <nb samples> <nb epoch>\n"); - return 1; - } - nbInputs = atoi(argv[1]); - nbHidden = atoi(argv[2]); - nbOutputs = atoi(argv[3]); - nbSamples = atoi(argv[4]); - nbEpoch = atoi(argv[5]); - nbRealInputs = nbInputs; - inputs = malloc(nbInputs*nbSamples*sizeof(*inputs)); - outputs = malloc(nbOutputs*nbSamples*sizeof(*outputs)); - - seed = time(NULL); - /*seed = 1452209040;*/ - fprintf (stderr, "Seed is %u\n", seed); - srand(seed); - build_tansig_table(); - signal(SIGTERM, handler); - signal(SIGINT, handler); - signal(SIGHUP, handler); - for (i=0;i<nbSamples;i++) - { - for (j=0;j<nbRealInputs;j++) - ret = scanf(" %f", &inputs[i*nbInputs+j]); - for (j=0;j<nbOutputs;j++) - ret = scanf(" %f", &outputs[i*nbOutputs+j]); - if (feof(stdin)) - { - nbSamples = i; - break; - } - } - int topo[3] = {nbInputs, nbHidden, nbOutputs}; - MLPTrain *net; - - fprintf (stderr, "Got %d samples\n", nbSamples); - net = mlp_init(topo, 3, inputs, outputs, nbSamples); - rms = mlp_train_backprop(net, inputs, outputs, nbSamples, nbEpoch, 1); - printf ("#ifdef HAVE_CONFIG_H\n"); - printf ("#include \"config.h\"\n"); - printf ("#endif\n\n"); - printf ("#include \"mlp.h\"\n\n"); - printf ("/* RMS error was %f, seed was %u */\n\n", rms, seed); - printf ("static const float weights[%d] = {\n", (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2]); - printf ("\n/* hidden layer */\n"); - for (i=0;i<(topo[0]+1)*topo[1];i++) - { - printf ("%gf,", net->weights[0][i]); - if (i%5==4) - printf("\n"); - else - printf(" "); - } - printf ("\n/* output layer */\n"); - for (i=0;i<(topo[1]+1)*topo[2];i++) - { - printf ("%gf,", net->weights[1][i]); - if (i%5==4) - printf("\n"); - else - printf(" "); - } - printf ("};\n\n"); - printf ("static const int topo[3] = {%d, %d, %d};\n\n", topo[0], topo[1], topo[2]); - printf ("const MLP net = {\n"); - printf (" 3,\n"); - printf (" topo,\n"); - printf (" weights\n};\n"); - return 0; -} diff --git a/src/mlp_train.h b/src/mlp_train.h deleted file mode 100644 index 49404158..00000000 --- a/src/mlp_train.h +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright (c) 2008-2011 Octasic Inc. - Written by Jean-Marc Valin */ -/* - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions - are met: - - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR - CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -*/ - -#ifndef _MLP_TRAIN_H_ -#define _MLP_TRAIN_H_ - -#include <math.h> -#include <stdlib.h> - -double tansig_table[501]; -static inline double tansig_double(double x) -{ - return 2./(1.+exp(-2.*x)) - 1.; -} -static inline void build_tansig_table(void) -{ - int i; - for (i=0;i<501;i++) - tansig_table[i] = tansig_double(.04*(i-250)); -} - -static inline double tansig_approx(double x) -{ - int i; - double y, dy; - if (x>=10) - return 1; - if (x<=-10) - return -1; - i = lrint(25*x); - x -= .04*i; - y = tansig_table[250+i]; - dy = 1-y*y; - y = y + x*dy*(1 - y*x); - return y; -} - -static inline float randn(float sd) -{ - float U1, U2, S, x; - do { - U1 = ((float)rand())/RAND_MAX; - U2 = ((float)rand())/RAND_MAX; - U1 = 2*U1-1; - U2 = 2*U2-1; - S = U1*U1 + U2*U2; - } while (S >= 1 || S == 0.0f); - x = sd*sqrt(-2 * log(S) / S) * U1; - return x; -} - - -typedef struct { - int layers; - int *topo; - double **weights; - double **best_weights; - double *in_rate; -} MLPTrain; - - -#endif /* _MLP_TRAIN_H_ */ diff --git a/src/opus_encoder.c b/src/opus_encoder.c index 3770fc64..0494170f 100644 --- a/src/opus_encoder.c +++ b/src/opus_encoder.c @@ -1189,7 +1189,16 @@ opus_int32 opus_encode_native(OpusEncoder *st, const opus_val16 *pcm, int frame_ { int analysis_bandwidth; if (st->signal_type == OPUS_AUTO) - st->voice_ratio = (int)floor(.5+100*(1-analysis_info.music_prob)); + { + float prob; + if (st->prev_mode == 0) + prob = analysis_info.music_prob; + else if (st->prev_mode == MODE_CELT_ONLY) + prob = analysis_info.music_prob_max; + else + prob = analysis_info.music_prob_min; + st->voice_ratio = (int)floor(.5+100*(1-prob)); + } analysis_bandwidth = analysis_info.bandwidth; if (analysis_bandwidth<=12) |