diff options
author | Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk> | 2018-01-13 14:12:27 +0300 |
---|---|---|
committer | Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk> | 2018-01-13 14:12:27 +0300 |
commit | ba0ea7491fab383992013a8379592657eedfe1ce (patch) | |
tree | 31e76ebac94526158e77552ef039283a5b8b2342 /scripts/contrib | |
parent | e6970cb5d5711997f689709396e1ef3ae9c74ac4 (diff) |
Add printing value for any key from model.npz
Diffstat (limited to 'scripts/contrib')
-rw-r--r-- | scripts/contrib/model_info.py | 48 |
1 files changed, 30 insertions, 18 deletions
diff --git a/scripts/contrib/model_info.py b/scripts/contrib/model_info.py index a08afcc9..4d14d785 100644 --- a/scripts/contrib/model_info.py +++ b/scripts/contrib/model_info.py @@ -6,37 +6,49 @@ import numpy as np import yaml -DESC = "Prints version and model type from model.npz file." +DESC = "Prints keys and values from model.npz file." S2S_SPECIAL_NODE = "special:model.yml" def main(): args = parse_args() - model = np.load(args.model) - if S2S_SPECIAL_NODE not in model: - print("No special Marian YAML node found in the model") - exit(1) - - yaml_text = bytes(model[S2S_SPECIAL_NODE]).decode('ascii') - if not args.key: - print(yaml_text) - exit(0) - - # fix the invalid trailing unicode character '#x0000' added to the YAML - # string by the C++ cnpy library - try: - yaml_node = yaml.load(yaml_text) - except yaml.reader.ReaderError: - yaml_node = yaml.load(yaml_text[:-1]) - print(yaml_node[args.key]) + if args.special: + if S2S_SPECIAL_NODE not in model: + print("No special Marian YAML node found in the model") + exit(1) + + yaml_text = bytes(model[S2S_SPECIAL_NODE]).decode('ascii') + if not args.key: + print(yaml_text) + exit(0) + + # fix the invalid trailing unicode character '#x0000' added to the YAML + # string by the C++ cnpy library + try: + yaml_node = yaml.load(yaml_text) + except yaml.reader.ReaderError: + yaml_node = yaml.load(yaml_text[:-1]) + + print(yaml_node[args.key]) + else: + if args.key: + if args.key not in model: + print("Key not found") + exit(1) + print(model[args.key]) + else: + for key in model: + print(key) def parse_args(): parser = argparse.ArgumentParser(description=DESC) parser.add_argument("-m", "--model", help="model file", required=True) parser.add_argument("-k", "--key", help="print value for specific key") + parser.add_argument("-s", "--special", action="store_true", + help="print values from special:model.yml node") return parser.parse_args() |