Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Buethe <jbuethe@amazon.de>2023-11-23 15:56:47 +0300
committerJan Buethe <jbuethe@amazon.de>2023-11-23 15:56:47 +0300
commit75ab4c2c337263c35202506fa55c5d09d35e8779 (patch)
treeeb093366dace0bfbb322460607d513df026001bd
parente097437e70ed4cfee8cd53d6a06ba9040f802f04 (diff)
bugfix
-rw-r--r--dnn/torch/osce/export_model_weights.py9
1 files changed, 3 insertions, 6 deletions
diff --git a/dnn/torch/osce/export_model_weights.py b/dnn/torch/osce/export_model_weights.py
index 8e543828..4b3c52bd 100644
--- a/dnn/torch/osce/export_model_weights.py
+++ b/dnn/torch/osce/export_model_weights.py
@@ -90,17 +90,14 @@ if __name__ == "__main__":
# create model and load weights
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model = model_dict[checkpoint['setup']['model']['name']](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs'])
+ model.load_state_dict(checkpoint['state_dict'])
# CWriter
model_name = checkpoint['setup']['model']['name']
cwriter = wexchange.c_export.CWriter(os.path.join(outdir, model_name + "_data"), message=message, model_struct_name=model_name.upper())
- # dump numbits_embedding parameters by hand
- numbits_embedding = model.get_submodule('numbits_embedding')
- weights = next(iter(numbits_embedding.parameters()))
- for i, c in enumerate(weights):
- cwriter.header.write(f"\nNUMBITS_COEF_{i} {float(c.detach())}f")
- cwriter.header.write("\n\n")
+ # Add custom includes
+ cwriter.header.write('\n#include "osce.h"\n')
# dump layers
osce_dump_generic(cwriter, model_name, model)