diff options
author | Jan Buethe <jbuethe@amazon.de> | 2023-12-07 17:26:33 +0300 |
---|---|---|
committer | Jan Buethe <jbuethe@amazon.de> | 2023-12-07 17:26:33 +0300 |
commit | 0c4417aa0100e6fe140cff97c408fc4e6428ffae (patch) | |
tree | 96cd7e71ece378cbd7cff13a2ec4a7381d4b917c | |
parent | b286c504971993cfd37dc54cd5ed9d38fb3938ac (diff) |
added option for maximal quantization to osce/export_model_weights.py
-rw-r--r-- | dnn/torch/osce/export_model_weights.py | 21 |
1 files changed, 11 insertions, 10 deletions
```diff
diff --git a/dnn/torch/osce/export_model_weights.py b/dnn/torch/osce/export_model_weights.py
index c3b723c7..786d3200 100644
--- a/dnn/torch/osce/export_model_weights.py
+++ b/dnn/torch/osce/export_model_weights.py
@@ -51,6 +51,7 @@ parser = argparse.ArgumentParser()
 
 parser.add_argument('checkpoint', type=str, help='LACE or NoLACE model checkpoint')
 parser.add_argument('output_dir', type=str, help='output folder')
+parser.add_argument('--quantize', action="store_true", help='quantization according to schedule')
 
 
 schedules = {
@@ -60,15 +61,15 @@ schedules = {
         ('feature_net.conv2', dict(quantize=True, scale=None)),
         ('feature_net.tconv', dict(quantize=True, scale=None)),
         ('feature_net.gru', dict()),
-        ('cf1', dict()),
-        ('cf2', dict()),
-        ('af1', dict()),
+        ('cf1', dict(quantize=True, scale=None)),
+        ('cf2', dict(quantize=True, scale=None)),
+        ('af1', dict(quantize=True, scale=None)),
         ('tdshape1', dict()),
         ('tdshape2', dict()),
         ('tdshape3', dict()),
-        ('af2', dict()),
-        ('af3', dict()),
-        ('af4', dict()),
+        ('af2', dict(quantize=True, scale=None)),
+        ('af3', dict(quantize=True, scale=None)),
+        ('af4', dict(quantize=True, scale=None)),
         ('post_cf1', dict(quantize=True, scale=None)),
         ('post_cf2', dict(quantize=True, scale=None)),
         ('post_af1', dict(quantize=True, scale=None)),
@@ -81,9 +82,9 @@ schedules = {
         ('feature_net.conv2', dict(quantize=True, scale=None)),
         ('feature_net.tconv', dict(quantize=True, scale=None)),
         ('feature_net.gru', dict()),
-        ('cf1', dict()),
-        ('cf2', dict()),
-        ('af1', dict())
+        ('cf1', dict(quantize=True, scale=None)),
+        ('cf2', dict(quantize=True, scale=None)),
+        ('af1', dict(quantize=True, scale=None))
     ]
 }
 
@@ -161,7 +162,7 @@ if __name__ == "__main__":
             cwriter.header.write(f"#define {model_name.upper()}_NUMBITS_SCALE_{i} {float(s.detach().cpu())}\n")
 
     # dump layers
-    if model_name in schedules:
+    if model_name in schedules and args.quantize:
         osce_scheduled_dump(cwriter, model_name, model, schedules[model_name])
     else:
         osce_dump_generic(cwriter, model_name, model)
```