diff options
author | Gregor Richards <hg-yff@gregor.im> | 2018-08-28 17:40:28 +0300 |
---|---|---|
committer | Jean-Marc Valin <jmvalin@jmvalin.ca> | 2019-05-29 07:37:07 +0300 |
commit | f30741bed8495e164049a495de89ac417f27ccf0 (patch) | |
tree | 69caa4533480771bce88be110149dd78ce8aa431 | |
parent | bfba2ad7a4a419383e839d661df0e69eb0f592e5 (diff) |
Made dump_rnn output in nu format.
-rwxr-xr-x | training/dump_rnn.py | 42 |
1 files changed, 23 insertions, 19 deletions
diff --git a/training/dump_rnn.py b/training/dump_rnn.py index 9f267a7..a9931b7 100755 --- a/training/dump_rnn.py +++ b/training/dump_rnn.py @@ -30,7 +30,7 @@ def printVector(f, vector, name): f.write('\n};\n\n') return; -def printLayer(f, hf, layer): +def printLayer(f, layer): weights = layer.get_weights() printVector(f, weights[0], layer.name + '_weights') if len(weights) > 2: @@ -39,19 +39,24 @@ def printLayer(f, hf, layer): name = layer.name activation = re.search('function (.*) at', str(layer.activation)).group(1).upper() if len(weights) > 2: - f.write('const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n' + f.write('static const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n' .format(name, name, name, name, weights[0].shape[0], weights[0].shape[1]/3, activation)) - hf.write('#define {}_SIZE {}\n'.format(name.upper(), weights[0].shape[1]/3)) - hf.write('extern const GRULayer {};\n\n'.format(name)); else: - f.write('const DenseLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n' + f.write('static const DenseLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n' .format(name, name, name, weights[0].shape[0], weights[0].shape[1], activation)) - hf.write('#define {}_SIZE {}\n'.format(name.upper(), weights[0].shape[1])) - hf.write('extern const DenseLayer {};\n\n'.format(name)); + +def structLayer(f, layer): + weights = layer.get_weights() + name = layer.name + if len(weights) > 2: + f.write(' {},\n'.format(weights[0].shape[1]/3)) + else: + f.write(' {},\n'.format(weights[0].shape[1])) + f.write(' &{},\n'.format(name)) def foo(c, name): - return 1 + return None def mean_squared_sqrt_error(y_true, y_pred): return K.mean(K.square(K.sqrt(y_pred) - K.sqrt(y_true)), axis=-1) @@ -62,27 +67,26 @@ model = load_model(sys.argv[1], custom_objects={'msse': mean_squared_sqrt_error, weights = model.get_weights() f = open(sys.argv[2], 'w') -hf = open(sys.argv[3], 'w') f.write('/*This file is automatically generated from a Keras model*/\n\n') f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "rnn.h"\n\n') -hf.write('/*This file is automatically generated from a Keras model*/\n\n') -hf.write('#ifndef RNN_DATA_H\n#define RNN_DATA_H\n\n#include "rnn.h"\n\n') - layer_list = [] for i, layer in enumerate(model.layers): if len(layer.get_weights()) > 0: - printLayer(f, hf, layer) + printLayer(f, layer) if len(layer.get_weights()) > 2: layer_list.append(layer.name) -hf.write('struct RNNState {\n') -for i, name in enumerate(layer_list): - hf.write(' float {}_state[{}_SIZE];\n'.format(name, name.upper())) -hf.write('};\n') +f.write('const struct RNNModel rnnoise_model_{} = {{\n'.format(sys.argv[3])) +for i, layer in enumerate(model.layers): + if len(layer.get_weights()) > 0: + structLayer(f, layer) +f.write('};\n') -hf.write('\n\n#endif\n') +#hf.write('struct RNNState {\n') +#for i, name in enumerate(layer_list): +# hf.write(' float {}_state[{}_SIZE];\n'.format(name, name.upper())) +#hf.write('};\n') f.close() -hf.close() |