diff options
author | Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk> | 2018-10-09 16:36:17 +0300 |
---|---|---|
committer | Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk> | 2018-10-09 16:36:17 +0300 |
commit | aed31cd5d5341712819772d27ec99e364f20e8d0 (patch) | |
tree | 2dd5976a5a985307e31b0e8072faacaffaff38e9 /scripts | |
parent | e53acb01a92f82e1e056bfaf70173cdcc5932989 (diff) |
Add script for injecting 'decoder_c_tt'
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/contrib/inject_ctt.py | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/scripts/contrib/inject_ctt.py b/scripts/contrib/inject_ctt.py new file mode 100644 index 00000000..751ee1c6 --- /dev/null +++ b/scripts/contrib/inject_ctt.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python + +from __future__ import print_function + +import sys +import argparse +import numpy as np + +DESC = "Add 'decoder_c_tt' required by Amun to a model trained with Marian v1.6.0+" + + +def main(): + args = parse_args() + + print("Loading model {}".format(args.input)) + model = np.load(args.input) + + if "decoder_c_tt" in model: + print("The model already contains 'decoder_c_tt'") + exit() + + print("Adding 'decoder_c_tt' to the model") + amun = {"decoder_c_tt": np.zeros((1, 0))} + for tensor_name in model: + amun[tensor_name] = model[tensor_name] + + print("Saving model...") + np.savez(args.output, **amun) + + +def parse_args(): + parser = argparse.ArgumentParser(description=DESC) + parser.add_argument("-i", "--input", help="input model", required=True) + parser.add_argument("-o", "--output", help="output model", required=True) + return parser.parse_args() + + +if __name__ == "__main__": + main() |