Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Marc Valin <jmvalin@jmvalin.ca>2024-01-17 10:26:48 +0300
committerJean-Marc Valin <jmvalin@jmvalin.ca>2024-01-17 10:26:48 +0300
commit4f311a1ad44f1b7bd60e32984ca0604c46b6c593 (patch)
tree7bc6041a00e98dd1ff926253e68cffb2c32ece6f
parent26ddfd713537accce773acc12f565021f4f6d28c (diff)
PLC export script
mostly untested
-rw-r--r--dnn/torch/plc/export_plc.py100
-rw-r--r--dnn/torch/plc/train_plc.py2
2 files changed, 101 insertions, 1 deletions
diff --git a/dnn/torch/plc/export_plc.py b/dnn/torch/plc/export_plc.py
new file mode 100644
index 00000000..7f153c4c
--- /dev/null
+++ b/dnn/torch/plc/export_plc.py
@@ -0,0 +1,100 @@
+"""
+/* Copyright (c) 2022 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+import sys
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('checkpoint', type=str, help='model checkpoint')
+parser.add_argument('output_dir', type=str, help='output folder')
+
+args = parser.parse_args()
+
+import torch
+import numpy as np
+
+import plc
+from wexchange.torch import dump_torch_weights
+from wexchange.c_export import CWriter, print_vector
+
+def c_export(args, model):
+
+ message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
+
+ writer = CWriter(os.path.join(args.output_dir, "plc_data"), message=message, model_struct_name='PLCModel')
+ writer.header.write(
+f"""
+#include "opus_types.h"
+"""
+ )
+
+ dense_layers = [
+ ('dense_in', "plc_dense_in"),
+ ('dense_out', "plc_dense_out")
+ ]
+
+
+ for name, export_name in dense_layers:
+ layer = model.get_submodule(name)
+ dump_torch_weights(writer, layer, name=export_name, verbose=True, quantize=False, scale=None)
+
+
+ gru_layers = [
+ ("gru1", "plc_gru1"),
+ ("gru2", "plc_gru2"),
+ ]
+
+ max_rnn_units = max([dump_torch_weights(writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=True, scale=None, recurrent_scale=None)
+ for name, export_name in gru_layers])
+
+ writer.header.write(
+f"""
+
+#define PLC_MAX_RNN_UNITS {max_rnn_units}
+
+"""
+ )
+
+ writer.close()
+
+
+if __name__ == "__main__":
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ checkpoint = torch.load(args.checkpoint, map_location='cpu')
+ model = plc.PLC(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+ model.load_state_dict(checkpoint['state_dict'], strict=False)
+ #checkpoint = torch.load(args.checkpoint, map_location='cpu')
+ #model.load_state_dict(checkpoint['state_dict'])
+ c_export(args, model)
diff --git a/dnn/torch/plc/train_plc.py b/dnn/torch/plc/train_plc.py
index 97be2c04..12b31c4e 100644
--- a/dnn/torch/plc/train_plc.py
+++ b/dnn/torch/plc/train_plc.py
@@ -138,7 +138,7 @@ if __name__ == '__main__':
)
# save checkpoint
- checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_{epoch}.pth')
+ checkpoint_path = os.path.join(checkpoint_dir, f'plc{args.suffix}_{epoch}.pth')
checkpoint['state_dict'] = model.state_dict()
checkpoint['loss'] = running_loss / len(dataloader)
checkpoint['epoch'] = epoch