diff options
author | Jan Buethe <jbuethe@amazon.de> | 2023-11-23 15:56:47 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2023-11-23 15:56:47 +0300 |
commit | 75ab4c2c337263c35202506fa55c5d09d35e8779 (patch) | |
tree | eb093366dace0bfbb322460607d513df026001bd | |
parent | e097437e70ed4cfee8cd53d6a06ba9040f802f04 (diff) |
bugfix
-rw-r--r-- | dnn/torch/osce/export_model_weights.py | 9 |
1 files changed, 3 insertions, 6 deletions
diff --git a/dnn/torch/osce/export_model_weights.py b/dnn/torch/osce/export_model_weights.py index 8e543828..4b3c52bd 100644 --- a/dnn/torch/osce/export_model_weights.py +++ b/dnn/torch/osce/export_model_weights.py @@ -90,17 +90,14 @@ if __name__ == "__main__": # create model and load weights checkpoint = torch.load(checkpoint_path, map_location='cpu') model = model_dict[checkpoint['setup']['model']['name']](*checkpoint['setup']['model']['args'], **checkpoint['setup']['model']['kwargs']) + model.load_state_dict(checkpoint['state_dict']) # CWriter model_name = checkpoint['setup']['model']['name'] cwriter = wexchange.c_export.CWriter(os.path.join(outdir, model_name + "_data"), message=message, model_struct_name=model_name.upper()) - # dump numbits_embedding parameters by hand - numbits_embedding = model.get_submodule('numbits_embedding') - weights = next(iter(numbits_embedding.parameters())) - for i, c in enumerate(weights): - cwriter.header.write(f"\nNUMBITS_COEF_{i} {float(c.detach())}f") - cwriter.header.write("\n\n") + # Add custom includes + cwriter.header.write('\n#include "osce.h"\n') # dump layers osce_dump_generic(cwriter, model_name, model) |