diff options
Diffstat (limited to 'dnn/torch/lpcnet/print_lpcnet_complexity.py')
-rw-r--r-- | dnn/torch/lpcnet/print_lpcnet_complexity.py | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/dnn/torch/lpcnet/print_lpcnet_complexity.py b/dnn/torch/lpcnet/print_lpcnet_complexity.py new file mode 100644 index 00000000..a47352be --- /dev/null +++ b/dnn/torch/lpcnet/print_lpcnet_complexity.py @@ -0,0 +1,35 @@ +import argparse + +import yaml + +from models import model_dict + + +debug = False +if debug: + args = type('dummy', (object,), + { + 'setup' : 'setups/lpcnet_m/setup_1_4_concatenative.yml', + 'hierarchical_sampling' : False + })() +else: + parser = argparse.ArgumentParser() + parser.add_argument('setup', type=str, help='setup yaml file') + parser.add_argument('--hierarchical-sampling', action="store_true", help='whether to assume hierarchical sampling (default=False)', default=False) + + args = parser.parse_args() + +with open(args.setup, 'r') as f: + setup = yaml.load(f.read(), yaml.FullLoader) + +# check model +if not 'model' in setup['lpcnet']: + print(f'warning: did not find model entry in setup, using default lpcnet') + model_name = 'lpcnet' +else: + model_name = setup['lpcnet']['model'] + +# create model +model = model_dict[model_name](setup['lpcnet']['config']) + +gflops = model.get_gflops(16000, verbose=True, hierarchical_sampling=args.hierarchical_sampling) |