Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Buethe <jbuethe@amazon.de>2023-11-07 17:12:12 +0300
committerJan Buethe <jbuethe@amazon.de>2023-11-07 17:12:12 +0300
commit4e104555e98c8227464f02ee388d983d387612b6 (patch)
tree8b2bb038c97ee1dea0fb7c1ae1e560fb3ed57ac8
parent8af5c6b4a13cb66e0f3dcd465c246d2d2e4128c7 (diff)
added weight export script for LACE/NoLACE
-rw-r--r--dnn/torch/osce/export_model_weights.py97
1 files changed, 97 insertions, 0 deletions
diff --git a/dnn/torch/osce/export_model_weights.py b/dnn/torch/osce/export_model_weights.py
new file mode 100644
index 00000000..8b95aca9
--- /dev/null
+++ b/dnn/torch/osce/export_model_weights.py
@@ -0,0 +1,97 @@
+"""
+/* Copyright (c) 2023 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+import sys
+
+import hashlib
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))
+
+import torch
+import wexchange.torch
+from wexchange.torch import dump_torch_weights
+from models import model_dict
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('checkpoint', type=str, help='LACE or NoLACE model checkpoint')
+parser.add_argument('output_dir', type=str, help='output folder')
+
+
+# auxiliary functions
+def sha1(filename):
+ BUF_SIZE = 65536
+ sha1 = hashlib.sha1()
+
+ with open(filename, 'rb') as f:
+ while True:
+ data = f.read(BUF_SIZE)
+ if not data:
+ break
+ sha1.update(data)
+
+ return sha1.hexdigest()
+
+def export_name(name):
+ return name.replace('.', '_')
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+
+ checkpoint_path = args.checkpoint
+ outdir = args.output_dir
+ os.makedirs(outdir, exist_ok=True)
+
+ # dump message
+ message = f"Auto generated from checkpoint {os.path.basename(checkpoint_path)} (sha1: {sha1(checkpoint_path)})"
+
+ # create model and load weights
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+ model = model_dict[checkpoint['setup']['model']['name']](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
+
+ # CWriter
+ model_name = checkpoint['setup']['model']['name']
+ cwriter = wexchange.c_export.CWriter(os.path.join(outdir, model_name + "_data"), message=message, model_struct_name=model_name.upper())
+
+ # dump numbits_embedding parameters by hand
+ numbits_embedding = model.get_submodule('numbits_embedding')
+ weights = next(iter(numbits_embedding.parameters()))
+ for i, c in enumerate(weights):
+ cwriter.header.write(f"\nNUMBITS_COEF_{i} {float(c.detach())}f")
+ cwriter.header.write("\n\n")
+
+ # dump layers
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear) or isinstance(module, torch.nn.Conv1d) \
+ or isinstance(module, torch.nn.ConvTranspose1d) or isinstance(module, torch.nn.Embedding):
+ dump_torch_weights(cwriter, module, name=export_name(name), verbose=True)
+
+ cwriter.close()