Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRoman Grundkiewicz <rgrundki@exseed.ed.ac.uk>2018-10-09 16:36:17 +0300
committerRoman Grundkiewicz <rgrundki@exseed.ed.ac.uk>2018-10-09 16:36:17 +0300
commitaed31cd5d5341712819772d27ec99e364f20e8d0 (patch)
tree2dd5976a5a985307e31b0e8072faacaffaff38e9 /scripts
parente53acb01a92f82e1e056bfaf70173cdcc5932989 (diff)
Add script for injecting 'decoder_c_tt'
Diffstat (limited to 'scripts')
-rw-r--r--scripts/contrib/inject_ctt.py39
1 files changed, 39 insertions, 0 deletions
diff --git a/scripts/contrib/inject_ctt.py b/scripts/contrib/inject_ctt.py
new file mode 100644
index 00000000..751ee1c6
--- /dev/null
+++ b/scripts/contrib/inject_ctt.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+
+from __future__ import print_function
+
+import sys
+import argparse
+import numpy as np
+
+DESC = "Add 'decoder_c_tt' required by Amun to a model trained with Marian v1.6.0+"
+
+
+def main():
+ args = parse_args()
+
+ print("Loading model {}".format(args.input))
+ model = np.load(args.input)
+
+ if "decoder_c_tt" in model:
+ print("The model already contains 'decoder_c_tt'")
+ exit()
+
+ print("Adding 'decoder_c_tt' to the model")
+ amun = {"decoder_c_tt": np.zeros((1, 0))}
+ for tensor_name in model:
+ amun[tensor_name] = model[tensor_name]
+
+ print("Saving model...")
+ np.savez(args.output, **amun)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description=DESC)
+ parser.add_argument("-i", "--input", help="input model", required=True)
+ parser.add_argument("-o", "--output", help="output model", required=True)
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ main()