Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Buethe <jbuethe@amazon.de>2023-09-29 16:34:59 +0300
committerJan Buethe <jbuethe@amazon.de>2023-09-29 16:34:59 +0300
commit0459a572f592fb07376c480c1ebbf04c16090211 (patch)
tree5c73f8562b7a22891763002666eaa3a979f90341
parentce28695844c12f43c31b4ee739749883c8b44b17 (diff)
updated PitchDNN export script
-rw-r--r--dnn/torch/neural-pitch/export_neuralpitch_weights.py46
1 files changed, 29 insertions, 17 deletions
diff --git a/dnn/torch/neural-pitch/export_neuralpitch_weights.py b/dnn/torch/neural-pitch/export_neuralpitch_weights.py
index a56784a9..9f20ec9e 100644
--- a/dnn/torch/neural-pitch/export_neuralpitch_weights.py
+++ b/dnn/torch/neural-pitch/export_neuralpitch_weights.py
@@ -44,7 +44,7 @@ args = parser.parse_args()
import torch
import numpy as np
-from models import large_if_ccode
+from models import PitchDNN
from wexchange.torch import dump_torch_weights
from wexchange.c_export import CWriter, print_vector
@@ -52,39 +52,51 @@ def c_export(args, model):
message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
- enc_writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='nnpitch')
- enc_writer.header.write(
+ writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='PitchDNN')
+ writer.header.write(
f"""
#include "opus_types.h"
"""
)
-
- # encoder
- encoder_dense_layers = [
- ('initial' , 'initial', 'TANH'),
- ('upsample' , 'upsample', 'TANH')
+ layers = [
+ ('if_upsample.0', "dense_if_upsampler_1"),
+ ('if_upsample.2', "dense_if_upsampler_2"),
+ ('conv.1', "conv2d_1"),
+ ('conv.4', "conv2d_2"),
+ ('conv.7', "conv2d_3"),
+ ('downsample.0', "dense_downsampler"),
+ ("upsample.0", "dense_final_upsampler")
]
- for name, export_name, _ in encoder_dense_layers:
+
+ for name, export_name in layers:
layer = model.get_submodule(name)
- dump_torch_weights(enc_writer, layer, name=export_name, verbose=True)
+ dump_torch_weights(writer, layer, name=export_name, verbose=True)
- encoder_gru_layers = [
- ('gru' , 'gru', 'TANH'),
+ gru_layers = [
+ ("GRU", "gru_1"),
]
- enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False)
- for name, export_name, _ in encoder_gru_layers])
+ max_rnn_units = max([dump_torch_weights(writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False)
+ for name, export_name in gru_layers])
+
+ writer.header.write(
+f"""
+
+#define PITCH_DNN_MAX_RNN_UNITS {max_rnn_units}
+
+"""
+ )
- del enc_writer
+ writer.close()
if __name__ == "__main__":
os.makedirs(args.output_dir, exist_ok=True)
- model = large_if_ccode()
- checkpoint = torch.load(args.checkpoint ,map_location='cpu')
+ model = PitchDNN()
+ checkpoint = torch.load(args.checkpoint, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
c_export(args, model)