github.com/stanfordnlp/stanza.git
author     John Bauer <horatio@gmail.com>    2022-11-07 23:05:35 +0300
committer  John Bauer <horatio@gmail.com>    2022-11-07 23:05:35 +0300
commit     a07cf172e2ba3c8c82b4195dec9d72364ad4c5e0
tree       98ca8942d1093d615834e38e17b2fd673795c079
parent     3a06cead4d81c5b35060c06895c2206113aa37cc

Change convert_pretrain to use argparse so it has a nice --help method

-rw-r--r--  stanza/models/common/convert_pretrain.py | 23
1 file changed, 12 insertions(+), 11 deletions(-)
diff --git a/stanza/models/common/convert_pretrain.py b/stanza/models/common/convert_pretrain.py
index 011638f2..45ddbb34 100644
--- a/stanza/models/common/convert_pretrain.py
+++ b/stanza/models/common/convert_pretrain.py
@@ -16,25 +16,26 @@ Note that if the pretrain already exists, nothing will be changed. It will not
"""
+import argparse
import os
import sys
from stanza.models.common import pretrain
def main():
- filename = sys.argv[1]
- if os.path.exists(filename):
- print("Not overwriting existing pretrain file in %s" % filename)
- vec_filename = sys.argv[2]
- if len(sys.argv) < 3:
- max_vocab = -1
- else:
- max_vocab = int(sys.argv[3])
+ parser = argparse.ArgumentParser()
+ parser.add_argument("output_pt", default=None, help="Where to write the converted PT file")
+ parser.add_argument("input_vec", default=None, help="Unconverted vectors file")
+ parser.add_argument("max_vocab", default=-1, nargs="?", help="How many vectors to convert. -1 means convert them all")
+ args = parser.parse_args()
+
+ if os.path.exists(args.output_pt):
+ print("Not overwriting existing pretrain file in %s" % args.output_pt)
- if vec_filename.endswith(".csv"):
- pt = pretrain.Pretrain(filename, max_vocab=max_vocab, csv_filename=vec_filename)
+ if args.input_vec.endswith(".csv"):
+ pt = pretrain.Pretrain(args.output_pt, max_vocab=args.max_vocab, csv_filename=args.input_vec)
else:
- pt = pretrain.Pretrain(filename, vec_filename, max_vocab=max_vocab)
+ pt = pretrain.Pretrain(args.output_pt, args.input_vec, max_vocab=args.max_vocab)
print("Pretrain is of size {}".format(len(pt.vocab)))
if __name__ == '__main__':
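
For reference, here is a minimal standalone sketch of the optional-positional argparse pattern this commit introduces. The parse_args helper, the description string, and the type=int are illustrative additions, not part of the diff. One detail worth noting: argparse hands positional values to the program as strings unless a type= is given, so where the old code called int(sys.argv[3]), the sketch uses type=int to keep max_vocab an integer when it is supplied on the command line; the integer default of -1 applies only when the argument is omitted.

import argparse

def parse_args(argv=None):
    # Illustrative parser mirroring the three arguments added in the diff above.
    parser = argparse.ArgumentParser(description="Convert a word vector file to a stanza pretrain (.pt) file")
    parser.add_argument("output_pt", help="Where to write the converted PT file")
    parser.add_argument("input_vec", help="Unconverted vectors file")
    # nargs="?" makes this positional optional; type=int converts a
    # command-line value to an integer instead of leaving it as a string.
    parser.add_argument("max_vocab", nargs="?", type=int, default=-1,
                        help="How many vectors to convert. -1 means convert them all")
    return parser.parse_args(argv)

if __name__ == '__main__':
    args = parse_args()
    print(args.output_pt, args.input_vec, args.max_vocab)

With the script on argparse, an invocation along the lines of python -m stanza.models.common.convert_pretrain --help (assuming the module is importable and calls main() under the __main__ guard) now prints a usage summary for the three arguments, which is the point of the change.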