diff options
author | Rico Sennrich <rico.sennrich@gmx.ch> | 2015-05-29 18:07:26 +0300 |
---|---|---|
committer | Rico Sennrich <rico.sennrich@gmx.ch> | 2015-05-29 18:07:26 +0300 |
commit | 5d8af9c2896d86785c5db2fd3a8029ae9b741e26 (patch) | |
tree | b99868426a8c941b995d85a39d9378801e66f6a9 /scripts/training/bilingual-lm/train_nplm.py | |
parent | ef028446f3640e007215b4576a4dc52a9c9de6db (diff) |
support memory-mapped files for NPLM training
Diffstat (limited to 'scripts/training/bilingual-lm/train_nplm.py')
-rwxr-xr-x | scripts/training/bilingual-lm/train_nplm.py | 14 |
1 file changed, 10 insertions, 4 deletions
```diff
diff --git a/scripts/training/bilingual-lm/train_nplm.py b/scripts/training/bilingual-lm/train_nplm.py
index cb5980a91..572076006 100755
--- a/scripts/training/bilingual-lm/train_nplm.py
+++ b/scripts/training/bilingual-lm/train_nplm.py
@@ -39,7 +39,8 @@ parser.add_argument("--input-words-file", dest="input_words_file")
 parser.add_argument("--output-words-file", dest="output_words_file")
 parser.add_argument("--input_vocab_size", dest="input_vocab_size", type=int)
 parser.add_argument("--output_vocab_size", dest="output_vocab_size", type=int)
-
+parser.add_argument("--mmap", dest="mmap", action="store_true",
+                    help="Use memory-mapped file (for lower memory consumption).")
 
 parser.set_defaults(
     working_dir="working",
@@ -113,6 +114,11 @@ def main(options):
         options.working_dir,
         os.path.basename(options.corpus_stem) + ".numberized")
 
+    mmap_command = []
+    if options.mmap:
+        in_file += '.mmap'
+        mmap_command = ['--mmap_file', '1']
+
     model_prefix = os.path.join(
         options.output_dir, options.output_model + ".model.nplm")
     train_args = [
@@ -127,9 +133,9 @@ def main(options):
         "--input_embedding_dimension", str(options.input_embedding),
         "--output_embedding_dimension", str(options.output_embedding),
         "--num_threads", str(options.threads),
-        "--activation_function",
-        options.activation_fn,
-    ] + validations_command + vocab_command
+        "--activation_function", options.activation_fn,
+        "--ngram_size", str(options.ngram_size),
+    ] + validations_command + vocab_command + mmap_command
     print("Train model command: ")
     print(', '.join(train_args))
 
```