diff options
author | John Bauer <horatio@gmail.com> | 2022-11-07 23:05:35 +0300 |
---|---|---|
committer | John Bauer <horatio@gmail.com> | 2022-11-07 23:05:35 +0300 |
commit | a07cf172e2ba3c8c82b4195dec9d72364ad4c5e0 (patch) | |
tree | 98ca8942d1093d615834e38e17b2fd673795c079 | |
parent | 3a06cead4d81c5b35060c06895c2206113aa37cc (diff) |
Change convert_pretrain to use argparse so it has a nice --help message
-rw-r--r-- | stanza/models/common/convert_pretrain.py | 23 |
1 file changed, 12 insertions, 11 deletions
```diff
diff --git a/stanza/models/common/convert_pretrain.py b/stanza/models/common/convert_pretrain.py
index 011638f2..45ddbb34 100644
--- a/stanza/models/common/convert_pretrain.py
+++ b/stanza/models/common/convert_pretrain.py
@@ -16,25 +16,26 @@ Note that if the pretrain already exists, nothing will be changed. It will not
 """
+import argparse
 import os
 import sys
 from stanza.models.common import pretrain
 def main():
-    filename = sys.argv[1]
-    if os.path.exists(filename):
-        print("Not overwriting existing pretrain file in %s" % filename)
-    vec_filename = sys.argv[2]
-    if len(sys.argv) < 3:
-        max_vocab = -1
-    else:
-        max_vocab = int(sys.argv[3])
+    parser = argparse.ArgumentParser()
+    parser.add_argument("output_pt", default=None, help="Where to write the converted PT file")
+    parser.add_argument("input_vec", default=None, help="Unconverted vectors file")
+    parser.add_argument("max_vocab", default=-1, nargs="?", help="How many vectors to convert. -1 means convert them all")
+    args = parser.parse_args()
+
+    if os.path.exists(args.output_pt):
+        print("Not overwriting existing pretrain file in %s" % args.output_pt)
-    if vec_filename.endswith(".csv"):
-        pt = pretrain.Pretrain(filename, max_vocab=max_vocab, csv_filename=vec_filename)
+    if args.input_vec.endswith(".csv"):
+        pt = pretrain.Pretrain(args.output_pt, max_vocab=args.max_vocab, csv_filename=args.input_vec)
     else:
-        pt = pretrain.Pretrain(filename, vec_filename, max_vocab=max_vocab)
+        pt = pretrain.Pretrain(args.output_pt, args.input_vec, max_vocab=args.max_vocab)
     print("Pretrain is of size {}".format(len(pt.vocab)))
 if __name__ == '__main__':
```