Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRoman Grundkiewicz <rgrundki@exseed.ed.ac.uk>2018-01-13 14:12:27 +0300
committerRoman Grundkiewicz <rgrundki@exseed.ed.ac.uk>2018-01-13 14:12:27 +0300
commitba0ea7491fab383992013a8379592657eedfe1ce (patch)
tree31e76ebac94526158e77552ef039283a5b8b2342 /scripts/contrib
parente6970cb5d5711997f689709396e1ef3ae9c74ac4 (diff)
Add printing value for any key from model.npz
Diffstat (limited to 'scripts/contrib')
-rw-r--r--scripts/contrib/model_info.py48
1 files changed, 30 insertions, 18 deletions
diff --git a/scripts/contrib/model_info.py b/scripts/contrib/model_info.py
index a08afcc9..4d14d785 100644
--- a/scripts/contrib/model_info.py
+++ b/scripts/contrib/model_info.py
@@ -6,37 +6,49 @@ import numpy as np
import yaml
-DESC = "Prints version and model type from model.npz file."
+DESC = "Prints keys and values from model.npz file."
S2S_SPECIAL_NODE = "special:model.yml"
def main():
args = parse_args()
-
model = np.load(args.model)
- if S2S_SPECIAL_NODE not in model:
- print("No special Marian YAML node found in the model")
- exit(1)
-
- yaml_text = bytes(model[S2S_SPECIAL_NODE]).decode('ascii')
- if not args.key:
- print(yaml_text)
- exit(0)
-
- # fix the invalid trailing unicode character '#x0000' added to the YAML
- # string by the C++ cnpy library
- try:
- yaml_node = yaml.load(yaml_text)
- except yaml.reader.ReaderError:
- yaml_node = yaml.load(yaml_text[:-1])
- print(yaml_node[args.key])
+ if args.special:
+ if S2S_SPECIAL_NODE not in model:
+ print("No special Marian YAML node found in the model")
+ exit(1)
+
+ yaml_text = bytes(model[S2S_SPECIAL_NODE]).decode('ascii')
+ if not args.key:
+ print(yaml_text)
+ exit(0)
+
+ # fix the invalid trailing unicode character '#x0000' added to the YAML
+ # string by the C++ cnpy library
+ try:
+ yaml_node = yaml.load(yaml_text)
+ except yaml.reader.ReaderError:
+ yaml_node = yaml.load(yaml_text[:-1])
+
+ print(yaml_node[args.key])
+ else:
+ if args.key:
+ if args.key not in model:
+ print("Key not found")
+ exit(1)
+ print(model[args.key])
+ else:
+ for key in model:
+ print(key)
def parse_args():
parser = argparse.ArgumentParser(description=DESC)
parser.add_argument("-m", "--model", help="model file", required=True)
parser.add_argument("-k", "--key", help="print value for specific key")
+ parser.add_argument("-s", "--special", action="store_true",
+ help="print values from special:model.yml node")
return parser.parse_args()