Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk>, 2017-07-04 14:43:31 +0300
committer: Roman Grundkiewicz <rgrundki@exseed.ed.ac.uk>, 2017-07-04 14:43:31 +0300
commit: 32f604a67fc72929cb13ea6aa883921a505ed960 (patch)
tree: b2f0751060e53c59c3e21260ef23929903144164 /scripts
parent: da15fd5717205174ba7d7a49bb4fe45f9533f6ab (diff)
Add --dim-voc option to process_word2vec.py
Diffstat (limited to 'scripts')
-rwxr-xr-x scripts/embeddings/process_word2vec.py | 29
1 file changed, 22 insertions, 7 deletions
diff --git a/scripts/embeddings/process_word2vec.py b/scripts/embeddings/process_word2vec.py
index 5be9adf7..4f5ba493 100755
--- a/scripts/embeddings/process_word2vec.py
+++ b/scripts/embeddings/process_word2vec.py
@@ -35,6 +35,10 @@ def main():
print(" lines: {}".format(lines))
print(" entries: {}".format(len(vocab)))
+ if args.dim_voc is not None:
+ vocab = {w: v for w, v in vocab.items() if v < args.dim_voc}
+ print(" loaded: {}".format(len(vocab)))
+
if args.word2vec:
print("Adding <unk> and </s> tokens to the corpus")
prep_corpus = args.input + '.prep'
@@ -56,21 +60,30 @@ def main():
orig_vectors = args.input
print("Replacing words with IDs in vector file")
- n = 1
- with open(orig_vectors) as cin, open(args.output, 'w+') as cout:
+ n = 0
+ embs = []
+ dim_emb = None
+ with open(orig_vectors) as cin:
for i, line in enumerate(cin):
if i == 0:
- cout.write(line)
+ dim_emb = line.strip().split(' ')[-1]
continue
word, tail = line.split(' ', 1)
if word in vocab:
- cout.write("{} {}".format(vocab[word], tail))
+ embs.append("{} {}".format(vocab[word], tail))
n += 1
else:
- print(" warning: no word '{}' in vocabulary, line {}".format(
- word, i + 1))
+ if not args.quiet:
+ print(" no word '{}' in vocabulary, line {}".format(
+ word, i + 1))
+
print(" words: {}".format(n))
+ with open(args.output, 'w') as cout:
+ cout.write("{} {}\n".format(n, dim_emb))
+ for emb in embs:
+ cout.write(emb)
+
print("Finished")
@@ -103,8 +116,10 @@ embedding vectors with regard to the word vocabulary."""
parser.add_argument("-o", "--output", help="output embedding vectors", required=True)
parser.add_argument("-v", "--vocab", help="path to vocabulary in JSON or YAML format", required=True)
parser.add_argument("-w", "--word2vec", help="path to word2vec, assumes text corpus on input")
- parser.add_argument("-d", "--dim-emb", help="size of embedding vector, only for training", default=512)
+ parser.add_argument("--dim-emb", help="size of embedding vector, only for training", default=512)
+ parser.add_argument("--dim-voc", help= "maximum number of words from vocabulary to be used, default: no limit", type=int)
parser.add_argument("-t", "--threads", help="number of threads", default=16)
+ parser.add_argument("--quiet", help="skip printing warnings", action='store_true')
return parser.parse_args()