Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.xiph.org/xiph/opus.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Buethe <jbuethe@amazon.de>2023-12-07 17:26:33 +0300
committerJan Buethe <jbuethe@amazon.de>2023-12-07 17:26:33 +0300
commit0c4417aa0100e6fe140cff97c408fc4e6428ffae (patch)
tree96cd7e71ece378cbd7cff13a2ec4a7381d4b917c
parentb286c504971993cfd37dc54cd5ed9d38fb3938ac (diff)
added option for maximal quantization to osce/export_model_weigts.py
-rw-r--r--dnn/torch/osce/export_model_weights.py21
1 files changed, 11 insertions, 10 deletions
diff --git a/dnn/torch/osce/export_model_weights.py b/dnn/torch/osce/export_model_weights.py
index c3b723c7..786d3200 100644
--- a/dnn/torch/osce/export_model_weights.py
+++ b/dnn/torch/osce/export_model_weights.py
@@ -51,6 +51,7 @@ parser = argparse.ArgumentParser()
parser.add_argument('checkpoint', type=str, help='LACE or NoLACE model checkpoint')
parser.add_argument('output_dir', type=str, help='output folder')
+parser.add_argument('--quantize', action="store_true", help='quantization according to schedule')
schedules = {
@@ -60,15 +61,15 @@ schedules = {
('feature_net.conv2', dict(quantize=True, scale=None)),
('feature_net.tconv', dict(quantize=True, scale=None)),
('feature_net.gru', dict()),
- ('cf1', dict()),
- ('cf2', dict()),
- ('af1', dict()),
+ ('cf1', dict(quantize=True, scale=None)),
+ ('cf2', dict(quantize=True, scale=None)),
+ ('af1', dict(quantize=True, scale=None)),
('tdshape1', dict()),
('tdshape2', dict()),
('tdshape3', dict()),
- ('af2', dict()),
- ('af3', dict()),
- ('af4', dict()),
+ ('af2', dict(quantize=True, scale=None)),
+ ('af3', dict(quantize=True, scale=None)),
+ ('af4', dict(quantize=True, scale=None)),
('post_cf1', dict(quantize=True, scale=None)),
('post_cf2', dict(quantize=True, scale=None)),
('post_af1', dict(quantize=True, scale=None)),
@@ -81,9 +82,9 @@ schedules = {
('feature_net.conv2', dict(quantize=True, scale=None)),
('feature_net.tconv', dict(quantize=True, scale=None)),
('feature_net.gru', dict()),
- ('cf1', dict()),
- ('cf2', dict()),
- ('af1', dict())
+ ('cf1', dict(quantize=True, scale=None)),
+ ('cf2', dict(quantize=True, scale=None)),
+ ('af1', dict(quantize=True, scale=None))
]
}
@@ -161,7 +162,7 @@ if __name__ == "__main__":
cwriter.header.write(f"#define {model_name.upper()}_NUMBITS_SCALE_{i} {float(s.detach().cpu())}\n")
# dump layers
- if model_name in schedules:
+ if model_name in schedules and args.quantize:
osce_scheduled_dump(cwriter, model_name, model, schedules[model_name])
else:
osce_dump_generic(cwriter, model_name, model)