github.com/stanfordnlp/stanza.git
Diffstat (limited to 'stanza/models/charlm.py')
-rw-r--r--  stanza/models/charlm.py  49
1 file changed, 29 insertions, 20 deletions
diff --git a/stanza/models/charlm.py b/stanza/models/charlm.py
index 6ffc14b8..36345121 100644
--- a/stanza/models/charlm.py
+++ b/stanza/models/charlm.py
@@ -13,6 +13,7 @@ import math
 import logging
 import time
 import os
+import lzma
 
 from stanza.models.common.char_model import CharacterLanguageModel
 from stanza.models.common.vocab import CharVocab
@@ -44,33 +45,41 @@ def get_batch(source, i, seq_len):
     target = source[:, i+1:i+1+seq_len].reshape(-1)
     return data, target
 
+def readlines(path):
+    if path.endswith(".xz"):
+        with lzma.open(path, mode='rt') as fin:
+            lines = fin.readlines()
+    else:
+        with open(path) as fin:
+            lines = fin.readlines() # preserve '\n'
+    return lines
+
 def build_vocab(path, cutoff=0):
     # Requires a large amount of memory, but only need to build once
+
+    # here we need some trick to deal with excessively large files
+    # for each file we accumulate the counter of characters, and
+    # at the end we simply pass a list of chars to the vocab builder
+    counter = Counter()
     if os.path.isdir(path):
-        # here we need some trick to deal with excessively large files
-        # for each file we accumulate the counter of characters, and
-        # at the end we simply pass a list of chars to the vocab builder
-        counter = Counter()
         filenames = sorted(os.listdir(path))
-        for filename in filenames:
-            lines = open(path + '/' + filename).readlines()
-            for line in lines:
-                counter.update(list(line))
-        # remove infrequent characters from vocab
-        for k in list(counter.keys()):
-            if counter[k] < cutoff:
-                del counter[k]
-        # a singleton list of all characters
-        data = [sorted([x[0] for x in counter.most_common()])]
-        vocab = CharVocab(data) # skip cutoff argument because this has been dealt with
     else:
-        lines = open(path).readlines() # reserve '\n'
-        data = [list(line) for line in lines]
-        vocab = CharVocab(data, cutoff=cutoff)
+        path, filenames = os.path.dirname(path), [os.path.basename(path)]
+    for filename in filenames:
+        lines = readlines(os.path.join(path, filename))
+        for line in lines:
+            counter.update(list(line))
+    # remove infrequent characters from vocab
+    for k in list(counter.keys()):
+        if counter[k] < cutoff:
+            del counter[k]
+    # a singleton list of all characters
+    data = [sorted([x[0] for x in counter.most_common()])]
+    vocab = CharVocab(data) # skip cutoff argument because this has been dealt with
     return vocab
 
 def load_file(path, vocab, direction):
-    lines = open(path).readlines() # reserve '\n'
+    lines = readlines(path)
     data = list(''.join(lines))
     idx = vocab['char'].map(data)
     if direction == 'backward': idx = idx[::-1]
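
For context, the behavior the new readlines helper relies on can be checked with a few lines of standalone Python (this sketch is not part of the patch; the file name corpus.txt.xz is illustrative). In text mode, lzma.open behaves like the built-in open:

import lzma

# Write a tiny compressed corpus (illustrative file name).
with lzma.open("corpus.txt.xz", mode="wt") as fout:
    fout.write("hello\nworld\n")

# In text mode, readlines() returns str lines with their
# trailing '\n' preserved, exactly like open().readlines().
with lzma.open("corpus.txt.xz", mode="rt") as fin:
    assert fin.readlines() == ["hello\n", "world\n"]
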
@@ -90,7 +99,7 @@ def load_data(path, vocab, direction):
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--train_file', type=str, help="Input plaintext file")
-    parser.add_argument('--train_dir', type=str, help="If non-emtpy, load from directory with multiple training files")
+    parser.add_argument('--train_dir', type=str, help="If non-empty, load from directory with multiple training files")
     parser.add_argument('--eval_file', type=str, help="Input plaintext file for the dev/test set")
     parser.add_argument('--lang', type=str, help="Language")
     parser.add_argument('--shorthand', type=str, help="UD treebank shorthand")
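
The cutoff pruning that build_vocab now applies uniformly to single files and directories can be traced with a toy example (not from the patch; the input string and cutoff value are made up):

from collections import Counter

counter = Counter("aaab")          # counts: {'a': 3, 'b': 1}
cutoff = 2
# list() makes a snapshot of the keys so entries can be
# deleted from the Counter while iterating.
for k in list(counter.keys()):
    if counter[k] < cutoff:
        del counter[k]
# A singleton list of the surviving characters, as CharVocab expects.
data = [sorted([x[0] for x in counter.most_common()])]
assert data == [['a']]             # 'b' fell below the cutoff
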